
Conversation

@djmmoss
Collaborator

@djmmoss djmmoss commented Nov 13, 2025

This PR adds attention sink support to the standard FlashInfer decode path.

pytest tests/attention/test_decode_sink_attention.py -xs
============================================================================================================================= test session starts ==============================================================================================================================
platform linux -- Python 3.10.12, pytest-9.0.1, pluggy-1.6.0
rootdir: /home/scratch.dmoss_gpu_1/repos/flashinfer
configfile: pytest.ini
collected 110 items

tests/attention/test_decode_sink_attention.py ..............................................................................................................

============================================================================================================================= 110 passed in 8.82s ==============================================================================================================================

Summary by CodeRabbit

  • New Features

    • Added optional auxiliary "sink" (maybe_s_aux/sinks) tensor support across prefill and decode flows, integrating sink values into attention softmax and threading auxiliary buffers through single- and batch-decode (including paged KV) paths.
    • Added input normalization helper to produce a kernel-friendly sink buffer format.
  • Tests

    • Added comprehensive tests validating sink-attention decoding across batch/single flows, layouts, GQA scenarios, paged KV, and numerical tolerances.

djmmoss and others added 2 commits November 13, 2025 11:23
Co-authored-by: Nitin Gupta <ngupta@groq.com>
Co-authored-by: Duncan Moss <djmmoss@gmail.com>
@coderabbitai
Contributor

coderabbitai bot commented Nov 13, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds optional per-head sink auxiliary (maybe_s_aux / sinks) and threads it through Python APIs, JIT module generation, utils normalization, prefill/decode param structs, and CUDA kernels; integrates s_aux into softmax denominator when present and adds tests covering batch and single decode sink-attention paths.
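For readers new to attention sinks, here is a minimal PyTorch sketch of the math the walkthrough describes (shapes, names, and the helper itself are illustrative, not the kernel's API): each query head gets one sink logit that joins the softmax denominator but contributes no value vector.

import torch

def decode_attention_with_sink_ref(q, k, v, s_aux, sm_scale):
    # q: [num_heads, head_dim]; k, v: [kv_len, num_heads, head_dim] (KV heads already
    # broadcast for GQA); s_aux: [num_heads], assumed to be in the same scale as the
    # scaled logits.
    logits = torch.einsum("hd,lhd->hl", q.float(), k.float()) * sm_scale  # [H, L]
    m = logits.max(dim=-1).values                                         # per-head running max
    p = torch.exp(logits - m[:, None])                                    # [H, L]
    denom = p.sum(dim=-1) + torch.exp(s_aux - m)                          # sink enters the denominator only
    return torch.einsum("hl,lhd->hd", p / denom[:, None], v.float())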

Changes

Cohort / File(s) Summary
Python decode & prefill API
flashinfer/decode.py, flashinfer/prefill.py
Added optional maybe_s_aux: Optional[torch.Tensor] / sinks: Optional[torch.Tensor] parameters to public run/single-decode/prefill entry points and forwarded them through JIT and non-JIT run paths.
JIT module generation
flashinfer/jit/attention/modules.py
Added "maybe_s_aux" to additional_tensor_names and its dtype to additional_tensor_dtypes for single-decode, single-prefill (fa2), and batch-decode module generation.
Utility
flashinfer/utils.py
Added _get_sink_buf(sinks: Optional[torch.Tensor]) -> Optional[torch.Tensor] to normalize optional sink tensors to float32 contiguous buffers (or return None).
Decode/prefill parameter structs
include/flashinfer/attention/default_decode_params.cuh, include/flashinfer/attention/default_prefill_params.cuh
Added float* maybe_s_aux members to Single/Batch Decode and Prefill parameter structs and updated constructors to initialize/accept the new pointer (default nullptr).
CUDA kernels / variants
include/flashinfer/attention/decode.cuh, include/flashinfer/attention/variants.cuh
When use_softmax and maybe_s_aux present, read per-head s_aux and add exp2((s_aux - max)*LOG2_E) into the softmax denominator after tile processing / before final output transform; changes guarded to preserve existing behavior when absent.
Tests
tests/attention/test_decode_sink_attention.py
New tests and reference helper for sink-attention: JIT warmup fixture and parametrized tests covering batch/single decode, GQA, layouts, page sizes, and numeric comparisons against a reference implementation.

Sequence Diagram

sequenceDiagram
    autonumber
    actor User
    participant PyAPI as Python API (decode/prefill)
    participant Utils as _get_sink_buf
    participant JIT as JIT Module
    participant Kernel as CUDA Kernel

    User->>PyAPI: run_* (q, ..., maybe_s_aux / sinks)
    PyAPI->>Utils: _get_sink_buf(maybe_s_aux) 
    Utils-->>PyAPI: sink_buf or None
    alt JIT path
        PyAPI->>JIT: invoke module with maybe_s_aux tensor
        JIT->>Kernel: kernel invoked (has maybe_s_aux)
    else non-JIT path
        PyAPI->>Kernel: call kernel with sink_ptr (from _get_sink_buf)
    end
    Kernel->>Kernel: if use_softmax && maybe_s_aux:\n compute exp2((s_aux - max)*LOG2_E) and add to denom
    Kernel-->>PyAPI: attention outputs
    PyAPI-->>User: return result

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

  • Focus areas:
    • Verify consistent threading of maybe_s_aux/sinks across Python → JIT → kernel for both JIT and non-JIT branches.
    • Inspect CUDA additions for correct indexing, guarding, and numerical stability of the exp2-based denominator update.
    • Confirm constructors and default nullptr handling in prefill/decode params.
    • Review new tests for coverage, tolerances, and correct reference comparisons.

Possibly related PRs

Suggested reviewers

  • yzh119
  • cyx-6
  • jiahanc
  • Anerudhan
  • joker-eph

Poem

🐰 I hopped a tiny buffer through the trees,
A whisper of sinks carried on the breeze.
From Python paths to CUDA's glowing halls,
Softmax learns another voice that calls.
Tests clap their paws — a gentle, happy thud.

Pre-merge checks and finishing touches

❌ Failed checks (1 warning, 1 inconclusive)
Check name | Status | Explanation | Resolution
Docstring Coverage | ⚠️ Warning | Docstring coverage is 46.67%, below the required threshold of 80.00%. | Run @coderabbitai generate docstrings to improve docstring coverage.
Description check | ❓ Inconclusive | The PR description explains what is being added (sink support to flashinfer decode) and provides test results showing 110 passing tests, but the Description section lacks detail and the pre-commit checklist items are not explicitly checked. | Expand the Description section with more detail on what sink support enables and why it's needed; explicitly confirm pre-commit checks and test status in the checklist.
✅ Passed checks (1 passed)
Check name | Status | Explanation
Title check | ✅ Passed | The title 'feat: add sink to flashinfer decode' clearly and concisely summarizes the main change: adding sink support to the decode functionality.


@gemini-code-assist
Contributor

Summary of Changes

Hello @djmmoss, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates 'sink attention' into FlashInfer's decode functionality, a mechanism vital for managing context length in large language models by allowing a learnable sink token to influence the attention softmax. The changes span the Python API, JIT compilation modules, and underlying CUDA kernels, ensuring a cohesive implementation. A new utility function facilitates proper handling of sink tensors, and a comprehensive suite of unit tests has been added to validate the feature's correctness and robustness across various configurations, including grouped query attention.
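For concreteness, a hedged usage sketch of the new argument on the tensor-core single-decode path (shapes, dtypes, and values are assumptions for illustration; the new tests in tests/attention/test_decode_sink_attention.py are the authoritative usage):

import torch
import flashinfer

num_qo_heads, num_kv_heads, head_dim, kv_len = 32, 8, 128, 256
q = torch.randn(num_qo_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.bfloat16)
v = torch.randn(kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.bfloat16)
# one learnable sink logit per query head
sinks = torch.randn(num_qo_heads, device="cuda", dtype=torch.float32) * 0.5

out = flashinfer.single_decode_with_kv_cache(
    q, k, v, kv_layout="NHD", use_tensor_cores=True, sinks=sinks
)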

Highlights

  • Sink Attention Feature: Implemented support for sink attention in FlashInfer's decode operations, allowing a learnable sink token to influence the softmax denominator.
  • CUDA Kernel Integration: Modified core CUDA kernels (SingleDecodeWithKVCacheKernel, BatchDecodeWithPagedKVCacheDevice) to incorporate the s_aux (sink) contribution to the softmax denominator during attention calculation.
  • Python API and JIT Module Extension: Extended Python decode functions (run_batch_decode, _fake_run_batch_decode) and JIT compilation modules to accept and process an optional sinks tensor, ensuring end-to-end integration.
  • Utility Function for Sink Tensors: Added a new utility function _get_sink_buf in flashinfer/utils.py to convert optional sink tensors to the required contiguous float32 format for CUDA kernels.
  • Comprehensive Testing: Introduced a new, dedicated test suite (test_decode_sink_attention.py) that includes a reference implementation, functional validation against the reference, tests for scenarios without sinks, and specific tests for grouped query attention (GQA) compatibility.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for sink attention in the flashinfer decode functionality. The changes are well-structured, touching upon Python bindings, JIT compilation, and the core C++ CUDA kernels. A comprehensive suite of tests has been added to validate the new feature. The implementation appears correct. I have one suggestion regarding the C++ kernel code to improve maintainability by reducing code duplication.

Comment on lines +358 to +365
  // Add s_aux (learnable sink) contribution to softmax denominator after all tiles processed
  if constexpr (variant.use_softmax) {
    if (params.maybe_s_aux != nullptr) {
      constexpr float LOG2_E = 1.4426950408889634f;  // log2(e)
      float s_aux_val = params.maybe_s_aux[qo_head_idx];
      st_local.d += math::ptx_exp2((s_aux_val - st_local.m) * LOG2_E);
    }
  }
Contributor


medium

This logic for adding the sink contribution is duplicated in BatchDecodeWithPagedKVCacheDevice (lines 600-607). To improve maintainability and reduce code duplication, consider extracting this block into a helper function.

Also, the constant LOG2_E is defined inline here and in the other location. It would be better to define it once at the top of the file in an anonymous namespace to avoid magic numbers and ensure consistency.

For example, you could add at the top of the file:

namespace flashinfer {

namespace { // anonymous namespace

static constexpr float LOG2_E = 1.4426950408889634f;  // log2(e)

template <typename AttentionVariant, typename State, typename Params>
__device__ __forceinline__ void AddSinkContribution(AttentionVariant variant, State& st,
                                                    const Params& params,
                                                    uint32_t qo_head_idx) {
  if constexpr (variant.use_softmax) {
    if (params.maybe_s_aux != nullptr) {
      float s_aux_val = params.maybe_s_aux[qo_head_idx];
      st.d += math::ptx_exp2((s_aux_val - st.m) * LOG2_E);
    }
  }
}

} // anonymous namespace

// ... rest of the file

Then you could replace this block and the one in BatchDecodeWithPagedKVCacheDevice with a call to this helper function:

AddSinkContribution(variant, st_local, params, qo_head_idx);

@yzh119
Collaborator

yzh119 commented Nov 13, 2025

Note: Adding attention sink support to the CUDA cores template (in decode.cuh) is a welcome addition for coverage. However, this template is scheduled for cleanup. Decode attention will default to use_tensor_cores=True, which uses the tensor core-based template defined in prefill.cuh.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
flashinfer/decode.py (2)

1260-1285: Validate and prepare sinks once (device + shape).

Before building run args, validate sinks length matches num_qo_heads and move to q.device. Prepare a buffer to reuse in both branches.

     if rope_theta is None:
         rope_theta = 1e4
+    # prepare sinks (optional)
+    sink_buf = None
+    if sinks is not None:
+        if sinks.dim() != 1 or sinks.numel() != q.shape[1]:
+            raise ValueError(f"sinks must be 1D with length num_qo_heads={q.shape[1]}, got {tuple(sinks.shape)}")
+        sink_buf = _get_sink_buf(sinks.to(q.device))

1291-1338: Remove sinks argument from tensor-core FA2 paged_run path only; non-tensor-core fix in review is incomplete.

The tensor-core issue is confirmed critical: FA2's paged_run kernel does not pass the sinks argument to the actual kernel, yet line 1337 passes _get_sink_buf(sinks). This will fail at runtime due to ABI mismatch.

However, the review's non-tensor-core fix is incomplete. The non-tensor-core decode module's run() method does accept sinks, but the review references a sink_buf variable that doesn't exist in the codebase. Line 1371 currently calls _get_sink_buf(sinks) inline. To apply the review's intent, either:

  1. Pre-compute sink_buf = _get_sink_buf(sinks) before the if/else block, then reference it, or
  2. Keep the inline _get_sink_buf(sinks) call in non-tensor-core as-is

Remove _get_sink_buf(sinks) from line 1337 in the tensor-core path to fix the critical issue.

🧹 Nitpick comments (1)
flashinfer/utils.py (1)

240-255: Clarify expected device/shape for sinks (keep helper simple).

Current helper does dtype/contiguity only. Upstream validation (device move to q.device and shape == num_qo_heads) should happen where q is available (decode.py). Consider adding a short note in the docstring that sinks must be on the same CUDA device and length == num_qo_heads.
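A minimal sketch of the helper together with the suggested docstring note (the actual implementation in flashinfer/utils.py may differ):

from typing import Optional

import torch


def _get_sink_buf(sinks: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """Normalize an optional per-head sink tensor into a kernel-friendly buffer.

    Note: ``sinks`` is expected to live on the same CUDA device as ``q`` and to
    have length ``num_qo_heads``; that validation belongs to the caller in decode.py.
    """
    if sinks is None:
        return None
    return sinks.to(torch.float32).contiguous()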

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 54101e9 and 99067e4.

📒 Files selected for processing (6)
  • flashinfer/decode.py (6 hunks)
  • flashinfer/jit/attention/modules.py (2 hunks)
  • flashinfer/utils.py (1 hunks)
  • include/flashinfer/attention/decode.cuh (2 hunks)
  • include/flashinfer/attention/default_decode_params.cuh (6 hunks)
  • tests/attention/test_decode_sink_attention.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
flashinfer/decode.py (1)
flashinfer/utils.py (1)
  • _get_sink_buf (240-254)
tests/attention/test_decode_sink_attention.py (3)
tests/test_helpers/sink_attention_reference.py (1)
  • sink_attention_unified (39-402)
flashinfer/jit/env.py (1)
  • has_flashinfer_jit_cache (27-36)
flashinfer/decode.py (7)
  • BatchDecodeWithPagedKVCacheWrapper (586-1422)
  • plan (815-1106)
  • plan (1615-1738)
  • run (1136-1149)
  • run (1152-1165)
  • run (1167-1386)
  • run (1740-1864)
🪛 Ruff (0.14.4)
tests/attention/test_decode_sink_attention.py

1-1: The file is executable but no shebang is present

(EXE002)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (2)
flashinfer/jit/attention/modules.py (1)

763-765: Batch-decode JIT addition looks correct.

Adding maybe_s_aux alongside maybe_alibi_slopes for batch decode matches the new params and kernel usage.

Please confirm generated JIT bindings reflect the new pointer in the correct order (tensor args before scalar args).

include/flashinfer/attention/default_decode_params.cuh (1)

40-63: Param extension LGTM.

maybe_s_aux added and zero-initialized across constructors; matches kernel guards.

Comment on lines +470 to 472
["maybe_alibi_slopes", "maybe_s_aux"], # additional_tensor_names
["float", "float"], # additional_tensor_dtypes
[
Contributor


⚠️ Potential issue | 🔴 Critical

Avoid breaking single-decode JIT: remove maybe_s_aux here (or plumb it through Python).

The single-decode Python wrapper doesn’t pass maybe_s_aux. Adding it here shifts the C++ run signature and will misalign subsequent args (logits_soft_cap becomes the second tensor param), breaking calls.

Minimal fix:

-        ["maybe_alibi_slopes", "maybe_s_aux"],  # additional_tensor_names
-        ["float", "float"],  # additional_tensor_dtypes
+        ["maybe_alibi_slopes"],  # additional_tensor_names
+        ["float"],  # additional_tensor_dtypes

If you want single-decode sink support, also update flashinfer/decode.py:get_single_decode_module wrappers to accept and pass maybe_s_aux in the correct position.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
["maybe_alibi_slopes", "maybe_s_aux"], # additional_tensor_names
["float", "float"], # additional_tensor_dtypes
[
["maybe_alibi_slopes"], # additional_tensor_names
["float"], # additional_tensor_dtypes
[
🤖 Prompt for AI Agents
In flashinfer/jit/attention/modules.py around lines 470-472, adding
"maybe_s_aux" to the additional tensor names shifts the C++ JIT run signature
and breaks the single-decode wrapper which does not pass that arg; remove
"maybe_s_aux" from the additional_tensor_names and its dtype from
additional_tensor_dtypes so the signature remains unchanged, or alternatively
update flashinfer/decode.py:get_single_decode_module to accept a maybe_s_aux
parameter and forward it in the exact positional order expected by the JIT run
(ensure both the names and dtypes lists and all wrapper call sites stay
consistent).

Comment on lines +358 to +365
  // Add s_aux (learnable sink) contribution to softmax denominator after all tiles processed
  if constexpr (variant.use_softmax) {
    if (params.maybe_s_aux != nullptr) {
      constexpr float LOG2_E = 1.4426950408889634f;  // log2(e)
      float s_aux_val = params.maybe_s_aux[qo_head_idx];
      st_local.d += math::ptx_exp2((s_aux_val - st_local.m) * LOG2_E);
    }
  }
Contributor


⚠️ Potential issue | 🔴 Critical

Fix s_aux scaling to match logits path.

s[j] is scaled by variant.sm_scale_log2 before softmax. s_aux currently uses (s_aux - m) * LOG2_E, which mismatches and yields incorrect normalization. Scale s_aux with variant.sm_scale_log2 and drop LOG2_E.

Apply:

-  if constexpr (variant.use_softmax) {
-    if (params.maybe_s_aux != nullptr) {
-      constexpr float LOG2_E = 1.4426950408889634f;  // log2(e)
-      float s_aux_val = params.maybe_s_aux[qo_head_idx];
-      st_local.d += math::ptx_exp2((s_aux_val - st_local.m) * LOG2_E);
-    }
-  }
+  if constexpr (variant.use_softmax) {
+    if (params.maybe_s_aux != nullptr) {
+      float s_aux_scaled = params.maybe_s_aux[qo_head_idx] * variant.sm_scale_log2;
+      st_local.d += math::ptx_exp2(s_aux_scaled - st_local.m);
+    }
+  }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
  // Add s_aux (learnable sink) contribution to softmax denominator after all tiles processed
  if constexpr (variant.use_softmax) {
    if (params.maybe_s_aux != nullptr) {
      constexpr float LOG2_E = 1.4426950408889634f;  // log2(e)
      float s_aux_val = params.maybe_s_aux[qo_head_idx];
      st_local.d += math::ptx_exp2((s_aux_val - st_local.m) * LOG2_E);
    }
  }
  // Add s_aux (learnable sink) contribution to softmax denominator after all tiles processed
  if constexpr (variant.use_softmax) {
    if (params.maybe_s_aux != nullptr) {
      float s_aux_scaled = params.maybe_s_aux[qo_head_idx] * variant.sm_scale_log2;
      st_local.d += math::ptx_exp2(s_aux_scaled - st_local.m);
    }
  }
🤖 Prompt for AI Agents
In include/flashinfer/attention/decode.cuh around lines 358 to 365, the s_aux contribution is being added using (s_aux - m) * LOG2_E, which mismatches the logits path scaling; scale the sink the same way as the logits by computing s_aux_scaled = params.maybe_s_aux[qo_head_idx] * variant.sm_scale_log2, change the call to math::ptx_exp2(s_aux_scaled - st_local.m), and remove LOG2_E; keep the same null check and use_softmax guard.

Comment on lines +601 to +607
  if constexpr (variant.use_softmax) {
    if (params.maybe_s_aux != nullptr) {
      constexpr float LOG2_E = 1.4426950408889634f;  // log2(e)
      float s_aux_val = params.maybe_s_aux[qo_head_idx];
      st.d += math::ptx_exp2((s_aux_val - st.m) * LOG2_E);
    }
  }
Contributor


⚠️ Potential issue | 🔴 Critical

Same scaling fix for batch kernel.

Mirror the s_aux scaling correction here to keep behavior consistent across kernels.

-  if constexpr (variant.use_softmax) {
-    if (params.maybe_s_aux != nullptr) {
-      constexpr float LOG2_E = 1.4426950408889634f;  // log2(e)
-      float s_aux_val = params.maybe_s_aux[qo_head_idx];
-      st.d += math::ptx_exp2((s_aux_val - st.m) * LOG2_E);
-    }
-  }
+  if constexpr (variant.use_softmax) {
+    if (params.maybe_s_aux != nullptr) {
+      float s_aux_scaled = params.maybe_s_aux[qo_head_idx] * variant.sm_scale_log2;
+      st.d += math::ptx_exp2(s_aux_scaled - st.m);
+    }
+  }

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In include/flashinfer/attention/decode.cuh around lines 601-607, the batch-kernel branch needs the same s_aux scaling fix as the non-batch path: drop LOG2_E, compute s_aux_scaled = params.maybe_s_aux[qo_head_idx] * variant.sm_scale_log2, pass (s_aux_scaled - st.m) to math::ptx_exp2, and add the result to st.d so the auxiliary scaling matches the other kernel.

Add maybe_s_aux support to prefill template used for decode attention
when use_tensor_cores=True. Includes updates to params structures,
variant handling, JIT generation, Python wrappers, and comprehensive
test coverage with validation.
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
include/flashinfer/attention/default_prefill_params.cuh (1)

1-1: Fix the clang-format errors before merging.

The pipeline indicates formatting violations. Please run clang-format on this file to resolve the pre-commit hook failure.

Run the following command to fix formatting:

#!/bin/bash
# Description: Format the file according to clang-format rules

clang-format -i include/flashinfer/attention/default_prefill_params.cuh
flashinfer/decode.py (3)

1170-1185: Sinks are silently ignored in JIT batch decode paths—confirmation required on whether this is intentional

The review comment is correct. Code inspection confirms:

  • Lines 1314-1343 (tensor-core path): When self._jit_module is not None, only list(args) extends run_args (line 1315); _get_sink_buf(sinks) is never added. In the non-JIT branch, _get_sink_buf(sinks) is correctly appended (line 1340).

  • Lines 1369-1381 (CUDA-core path): Identical pattern—JIT branch (line 1370) extends with list(args) only; non-JIT branch adds _get_sink_buf(sinks) (line 1374).

Since sinks is a named parameter (not part of *args), it is not captured by list(args) and is completely dropped for JIT paths. This creates a silent API inconsistency where the parameter is accepted but only respected in non-JIT code, risking hard-to-trace behavioral divergence.

The suggested fix (fail-fast with ValueError if JIT is enabled and sinks are provided, or thread sinks into JIT module calls) is appropriate.
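A minimal sketch of the fail-fast option (the helper name and its placement inside run() are hypothetical):

from typing import Any, Optional

import torch


def _check_sinks_supported(jit_module: Optional[Any], sinks: Optional[torch.Tensor]) -> None:
    # Called before building run_args in both the tensor-core and CUDA-core branches.
    if jit_module is not None and sinks is not None:
        raise ValueError(
            "sinks is not supported when running through a custom JIT module; "
            "either thread maybe_s_aux into the JIT run args or omit sinks"
        )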


1873-1932: Update sinks parameter type annotations from List[torch.Tensor] to torch.Tensor in three decode functions

The review correctly identifies a type annotation mismatch. Three public functions advertise sinks: Optional[List[torch.Tensor]] but the code internally expects a single torch.Tensor:

  • trtllm_batch_decode_with_kv_cache (line 2083)
  • trtllm_batch_decode_with_kv_cache_mla (line 2547)
  • xqa_batch_decode_with_kv_cache_mla (line 2715)

This causes a runtime error when callers pass a list, as the implementation calls sinks.reshape(num_kv_heads, -1) at line 2459, which requires a tensor.

Change all three function signatures from sinks: Optional[List[torch.Tensor]] = None to sinks: Optional[torch.Tensor] = None, and update the docstring at line 2137 to match.


354-392: Expose sinks in all overloads and document its semantics for single_decode_with_kv_cache

The first overload (with return_lse: Literal[False]) does not list sinks, while the second overload and implementation do. Type checkers will reject valid calls like single_decode_with_kv_cache(..., sinks=...) when return_lse=False. Additionally, the docstring does not describe sinks at all, nor document that it is currently only honored in the use_tensor_cores=True path and silently ignored otherwise.

Required changes:

  1. Add sinks to the first overload (line 354-371):

    return_lse: Literal[False] = False,
    sinks: Optional[torch.Tensor] = None,
    ) -> torch.Tensor: ...
  2. Document sinks in the docstring (add to Parameters section after return_lse):

    sinks : Optional[torch.Tensor]
        Optional per-head sink values of shape ``[num_qo_heads]`` added to the
        softmax denominator. Currently only supported when
        ``use_tensor_cores=True``; for the non-tensor-core decode kernel this
        argument is ignored.
    

Verify with mypy or pyright that both overloads now accept sinks calls. Consider raising a ValueError when sinks is not None and not use_tensor_cores to make the limitation explicit rather than silent.

♻️ Duplicate comments (1)
tests/attention/test_decode_sink_attention.py (1)

258-322: GQA sink-attention path is now explicitly covered

test_batch_decode_sink_attention_gqa sets num_qo_heads=16 and num_kv_heads=8, so the grouped‑query broadcast path is actually exercised, addressing earlier feedback. The sanity checks on shape, dtype, and NaN/Inf are appropriate for this focused test.

No action needed; this resolves the earlier concern about GQA coverage in this area.

🧹 Nitpick comments (3)
tests/attention/test_decode_sink_attention.py (3)

48-55: JIT warmup fixture is harmless but effectively a no-op

The warmup_jit fixture is autouse only when no flashinfer_jit_cache is present, but its body just yields. That’s fine (and cheap), though if the intention was to proactively trigger decode JIT compilation, you may later want to add a minimal decode call here; otherwise this is acceptable as is.

If you later add true JIT warmup logic, please ensure it’s guarded by CUDA availability checks to avoid spurious failures on non‑GPU test environments.
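If real warmup is added later, a sketch of what it could look like (shapes and fixture scope are assumptions, and the import path is taken from the code graph above; treat this as a suggestion, not the test file's current content):

import pytest
import torch

import flashinfer
from flashinfer.jit.env import has_flashinfer_jit_cache


@pytest.fixture(autouse=not has_flashinfer_jit_cache(), scope="module")
def warmup_jit():
    # Trigger decode JIT compilation once per module, only when a GPU is present.
    if torch.cuda.is_available():
        q = torch.randn(8, 64, device="cuda", dtype=torch.float16)
        k = torch.randn(32, 8, 64, device="cuda", dtype=torch.float16)
        v = torch.randn(32, 8, 64, device="cuda", dtype=torch.float16)
        flashinfer.single_decode_with_kv_cache(q, k, v)
    yield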


58-189: Strong coverage for batch decode + sinks across MHA/GQA and paging configs

test_batch_decode_with_sink_attention exercises a wide grid of (batch_size, kv_len, num_qo_heads, num_kv_heads, head_dim, page_size) with a GQA case and validates against the reference implementation using tolerances that are reasonable for bf16 and different reduction orders. The paged‑KV reconstruction logic from kv_data_fp32 into [B, kv_len, num_kv_heads, D] looks correct for both full and partial pages.

In a future pass, you might consider adding a non‑trivial window_left value and an HND layout here to exercise those branches more directly, but it isn’t strictly necessary for this PR.


407-408: Address Ruff EXE002 by dropping the __main__ block or adding a shebang

The if __name__ == "__main__": pytest.main(...) block combined with the executable bit triggers Ruff’s EXE002 (“file is executable but no shebang”). Since this is a pytest module, the simplest fix is to remove the __main__ block entirely and rely on pytest discovery:

-if __name__ == "__main__":
-    pytest.main([__file__, "-v"])

Alternatively, if you really want direct execution, add an appropriate shebang (#!/usr/bin/env python3) at the top and keep the block.

After adjusting, re-run pre-commit locally to ensure Ruff no longer reports EXE002 for this file.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 99067e4 and 58f66f8.

📒 Files selected for processing (6)
  • flashinfer/decode.py (9 hunks)
  • flashinfer/jit/attention/modules.py (2 hunks)
  • flashinfer/prefill.py (3 hunks)
  • include/flashinfer/attention/default_prefill_params.cuh (12 hunks)
  • include/flashinfer/attention/variants.cuh (1 hunks)
  • tests/attention/test_decode_sink_attention.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
flashinfer/decode.py (1)
flashinfer/utils.py (1)
  • _get_sink_buf (240-254)
tests/attention/test_decode_sink_attention.py (2)
tests/test_helpers/sink_attention_reference.py (1)
  • sink_attention_unified (39-402)
flashinfer/decode.py (12)
  • BatchDecodeWithPagedKVCacheWrapper (589-1425)
  • plan (818-1109)
  • plan (1618-1741)
  • run (1139-1152)
  • run (1155-1168)
  • run (1170-1389)
  • run (1743-1867)
  • single_decode_with_kv_cache (355-371)
  • single_decode_with_kv_cache (375-392)
  • single_decode_with_kv_cache (395-586)
  • use_tensor_cores (787-788)
  • use_tensor_cores (1591-1592)
🪛 GitHub Actions: pre-commit
flashinfer/jit/attention/modules.py

[error] 1-1: clang-format formatting failed. Hooks modified the file.

tests/attention/test_decode_sink_attention.py

[error] 1-1: ruff-format formatting failed. Hooks reformatted 2 files.

include/flashinfer/attention/default_prefill_params.cuh

[error] 1-1: clang-format formatting failed. Hooks modified the file.

🪛 Ruff (0.14.4)
tests/attention/test_decode_sink_attention.py

1-1: The file is executable but no shebang is present

(EXE002)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (11)
include/flashinfer/attention/variants.cuh (1)

94-102: LGTM: Sink support implementation is correct.

The M-D update hook correctly implements the attention sink mechanism by augmenting the softmax denominator d with an exponential term derived from the auxiliary sink values. The conditional checks ensure the update only applies when softmax is enabled and sink values are provided.

flashinfer/jit/attention/modules.py (1)

519-520: Note: fa2 backend receives maybe_s_aux; fa3 does not.

The maybe_s_aux tensor is added only to the fa2 backend's single prefill module. The fa3 backend path (lines 530-544) does not include this tensor. Verify this is intentional and document if fa3 sink support is planned for future work.

flashinfer/prefill.py (2)

334-334: LGTM: maybe_s_aux correctly forwarded in fa2 backend.

The parameter is properly threaded through to the kernel invocation for the fa2 backend path, consistent with the JIT module changes.


280-280: Based on my comprehensive verification:

maybe_s_aux is not exposed in the public API and is never passed from callers—this is intentional design, not a bug.

Attention sinks are conditionally selected for both prefill and decode when attention sinks are enabled, but the current implementation shows that support depends on upcoming FlashInfer features. The codebase confirms:

  1. No public API exposure: single_prefill_with_kv_cache (the public API) has no sink parameter
  2. Decode-only usage: The sinks tensor is passed to the attention calls only in the decode path, not prefill
  3. Internal parameter only: maybe_s_aux exists in run_single_prefill (internal JIT function) but module.run() doesn't pass it (remains None implicitly)

The behavior is consistent: sink token support is currently a decode-only feature. The parameter exists in the internal function signature but is never populated from the public API layer.

include/flashinfer/attention/default_prefill_params.cuh (3)

41-41: LGTM: maybe_s_aux consistently added to SinglePrefillParams.

The new member is properly declared, initialized to nullptr in the default constructor, and correctly initialized in the parameterized constructor following the same pattern as maybe_alibi_slopes.

Also applies to: 70-70, 92-92, 105-105


153-153: LGTM: maybe_s_aux consistently added to BatchPrefillRaggedParams.

The changes follow the same pattern as SinglePrefillParams with proper initialization in both constructors.

Also applies to: 198-198, 233-233, 251-251


307-307: LGTM: maybe_s_aux consistently added to BatchPrefillPagedParams.

All three parameter structs now consistently support the optional maybe_s_aux auxiliary tensor.

Also applies to: 344-344, 374-374, 387-387

flashinfer/decode.py (1)

44-68: Sink auxiliary buffer (maybe_s_aux) is wired consistently into batch decode custom op

The added maybe_s_aux parameter is threaded from the Python custom-op signature into run_func in the same relative position as alibi_slopes, which keeps the call layout coherent with the CUDA side. No issues here from the Python side; just ensure the C++ kernel and JIT templates were updated to the same signature ordering.

Please re-run the existing batch decode tests (including the new sink tests) once more after any low‑level changes to confirm there’s no ABI mismatch.

Also applies to: 219-247, 252-273, 276-299

tests/attention/test_decode_sink_attention.py (3)

27-45: Reference helper cleanly reuses the unified sink attention implementation

sink_attention_decode_ref wrapping sink_attention_unified in "incremental" mode gives a clear, single source of truth for expected decode behavior with sinks. This is a solid choice to keep the tests aligned with the reference implementation.

Rely on the existing tests/test_helpers/sink_attention_reference.py coverage to validate that the reference stays in sync with any future kernel changes.


191-256: Clear no-sink vs zero-sink equivalence test

test_batch_decode_without_sink_attention nicely checks that sinks=None and an all‑zero sink tensor produce numerically equivalent outputs within a tight tolerance, which is exactly the subtle regression this new plumbing could introduce. The setup (single configuration, NHD layout, fixed sizes) is sufficient as a smoke test.

Keep this test around as you evolve sink handling; it will quickly catch any accidental change in the “no‑sink” fast path semantics.


324-405: Single-token tensor-core decode with sinks is well validated against the reference

test_single_decode_sink_attention_tensor_cores:

  • Covers both kv_layout="NHD" and "HND" by reshaping/transposing into the reference’s [B, kv_len, num_kv_heads, D] format.
  • Uses bf16 and realistic sm_scale, with sinks in a logit‑scaled range.
  • Compares the tensor-core path output to sink_attention_decode_ref with tolerances consistent with the batch test, plus basic shape/dtype/NaN/Inf checks.

This gives good confidence that the single‑decode tensor‑core sink path matches the intended math.

If you later enable sinks for the non–tensor-core single decode kernel, adding a similar test that toggles use_tensor_cores=False will help guard that path as well.
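A hypothetical sketch of such a follow-up test, assuming the non-tensor-core single-decode kernel accepts sinks by then (shapes and tolerances are placeholders):

import torch

import flashinfer


def test_single_decode_sink_matches_across_kernels():
    torch.manual_seed(0)
    num_qo_heads, num_kv_heads, head_dim, kv_len = 32, 8, 128, 256
    q = torch.randn(num_qo_heads, head_dim, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.bfloat16)
    sinks = torch.randn(num_qo_heads, device="cuda", dtype=torch.float32) * 0.5

    # Compare the tensor-core and CUDA-core paths directly rather than against a reference.
    out_tc = flashinfer.single_decode_with_kv_cache(q, k, v, use_tensor_cores=True, sinks=sinks)
    out_cc = flashinfer.single_decode_with_kv_cache(q, k, v, use_tensor_cores=False, sinks=sinks)
    torch.testing.assert_close(out_tc.float(), out_cc.float(), rtol=2e-2, atol=2e-2)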

Update single decode module's additional_tensor_names to include 'maybe_s_aux'
along with the corresponding 'float' entry in additional_tensor_dtypes.
This matches the batch decode definition and enables sink attention support
for single decode operations.
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (2)
tests/attention/test_decode_sink_attention.py (2)

294-294: Inconsistent sink value generation.

This test uses torch.rand(num_qo_heads) * 5.0 (uniform distribution [0, 5.0]) while other tests use torch.randn(num_qo_heads) * 0.5 (normal distribution scaled by 0.5). For consistency and alignment with the documented scale expectations ("similar scale to logits"), consider using the same generation pattern.

Apply this diff:

-    sinks = torch.rand(num_qo_heads, device=device, dtype=torch.float32) * 5.0
+    sinks = torch.randn(num_qo_heads, device=device, dtype=torch.float32) * 0.5

258-322: Consider consolidating or enhancing this GQA test.

This test performs only basic sanity checks (shape, dtype, NaN/Inf) without validating against the reference implementation. Since test_batch_decode_with_sink_attention already includes GQA scenarios (32:8 ratio) with full numerical validation, this test may be redundant. Consider either:

  • Removing it in favor of the comprehensive test coverage above, or
  • Adding reference validation to ensure GQA-specific logic is thoroughly tested
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 58f66f8 and 62859a9.

📒 Files selected for processing (3)
  • flashinfer/jit/attention/modules.py (3 hunks)
  • include/flashinfer/attention/default_prefill_params.cuh (12 hunks)
  • tests/attention/test_decode_sink_attention.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/attention/test_decode_sink_attention.py (3)
tests/test_helpers/sink_attention_reference.py (1)
  • sink_attention_unified (39-402)
flashinfer/jit/env.py (1)
  • has_flashinfer_jit_cache (27-36)
flashinfer/decode.py (12)
  • BatchDecodeWithPagedKVCacheWrapper (589-1425)
  • plan (818-1109)
  • plan (1618-1741)
  • run (1139-1152)
  • run (1155-1168)
  • run (1170-1389)
  • run (1743-1867)
  • single_decode_with_kv_cache (355-371)
  • single_decode_with_kv_cache (375-392)
  • single_decode_with_kv_cache (395-586)
  • use_tensor_cores (787-788)
  • use_tensor_cores (1591-1592)
🪛 Ruff (0.14.5)
tests/attention/test_decode_sink_attention.py

1-1: The file is executable but no shebang is present

(EXE002)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (10)
include/flashinfer/attention/default_prefill_params.cuh (1)

41-41: LGTM! Consistent addition of sink auxiliary support across all parameter structures.

The maybe_s_aux member is correctly added to all three prefill parameter structures (SinglePrefillParams, BatchPrefillRaggedParams, BatchPrefillPagedParams) with proper:

  • Member declaration placement (after maybe_alibi_slopes)
  • Nullptr initialization in default constructors
  • Integration into parameterized constructors with correct initializer list ordering

The implementation follows existing patterns and maintains consistency across the codebase.

Also applies to: 70-70, 91-91, 104-104, 152-152, 197-197, 232-234, 249-249, 305-305, 342-342, 372-374, 384-384

tests/attention/test_decode_sink_attention.py (7)

1-26: LGTM!

The imports and copyright header are clean and all dependencies are properly utilized throughout the test module.


27-45: LGTM!

The reference wrapper correctly delegates to sink_attention_unified with appropriate decode-specific parameters (causal=True, mode="incremental").


48-55: Clarify the purpose of this empty fixture.

The warmup_jit fixture has autouse=not has_flashinfer_jit_cache() but its body only contains a yield with no warmup logic. The comment "This will be built on-demand during tests" suggests lazy compilation, but the fixture name implies it performs warmup. Either implement actual warmup logic or clarify why an empty fixture is needed here.


58-189: LGTM!

Excellent comprehensive test coverage for batch decode with sink attention. The test properly:

  • Covers various batch sizes, KV lengths, head configurations (including GQA), dimensions, and page sizes
  • Constructs paged KV caches with correct metadata (indptr, indices, last_page_len)
  • Converts paged format to reference format for both NHD and HND layouts
  • Validates results against the reference implementation with appropriate bf16-aware tolerances

191-255: LGTM!

This test appropriately validates that the no-sink path (sinks=None) produces results equivalent to zero sinks with relaxed tolerances that account for code path differences and bf16 precision.


324-408: LGTM!

This test provides excellent coverage for the single decode path with tensor cores, including:

  • Multiple KV lengths and head configurations (MHA and GQA)
  • Both NHD and HND layouts with proper format conversion for reference comparison
  • Numerical validation against the reference implementation with appropriate bf16 tolerances

411-412: LGTM!

Standard pytest entry point for direct execution.

flashinfer/jit/attention/modules.py (2)

470-471: Python wrappers properly accept and pass sinks parameter.

Verification confirms that Python wrappers have been updated to include the sinks parameter (e.g., lines 391, 411 in single_decode_with_kv_cache, and additional functions at lines 1183, 1896, 1990, 2055, 2083). The parameter is correctly forwarded to the JIT modules as maybe_s_aux at line 538 and throughout batch decode operations. The integration is consistent across all modified functions.


767-768: Based on the verification, the batch decode Python wrappers have been successfully updated to accept and pass the sinks parameter:

Evidence:

  • The trtllm_batch_decode_with_kv_cache function signature includes sinks: List[torch.Tensor] | None = None
  • Line 1183 in flashinfer/decode.py shows the run() method accepting sinks: Optional[torch.Tensor] = None
  • Lines 1340 and 1374 demonstrate the sinks being processed via _get_sink_buf(sinks) in both paged and non-paged execution paths
  • Test files confirm successful usage: wrapper.run(q, kv_data, sinks=sinks) is called across multiple test cases with the parameter properly forwarded

The changes to the JIT modules are consistent with the wrapper implementations—all three decode variants (single decode, batch decode with paged KV, and batch decode with padded KV) now uniformly support the maybe_s_aux / sinks parameter.

Comment on lines +519 to +524
additional_tensor_names = [
    "maybe_custom_mask",
    "maybe_alibi_slopes",
    "maybe_s_aux",
]
additional_tensor_dtypes = ["uint8_t", "float", "float"]
Contributor


⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Verify that single prefill Python wrappers accept and pass maybe_s_aux parameter.

The single prefill module (fa2 backend) now includes maybe_s_aux as the third additional tensor parameter. Ensure that Python wrappers for single prefill operations (e.g., flashinfer.single_prefill_with_kv_cache, flashinfer/prefill.py:get_single_prefill_module) have been updated to accept and forward the sinks parameter.

Run the following script to verify the prefill wrapper signatures:


🏁 Script executed:

#!/bin/bash
# Description: Check if Python wrappers for single prefill have been updated to handle sinks parameter

echo "=== Checking single_prefill_with_kv_cache function signature ==="
rg -n -A5 'def single_prefill_with_kv_cache' --type=py

echo ""
echo "=== Checking get_single_prefill_module calls ==="
rg -n -B3 -A10 'get_single_prefill_module' --type=py

echo ""
echo "=== Checking for sinks parameter in prefill.py ==="
rg -n -C5 'sinks' python/flashinfer/prefill.py 2>/dev/null || rg -n -C5 'sinks' flashinfer/prefill.py

Length of output: 10216


🏁 Script executed:

#!/bin/bash
# Get full function signatures for single_prefill_with_kv_cache functions

echo "=== Full signature of single_prefill_with_kv_cache at line 879 ==="
sed -n '879,920p' flashinfer/prefill.py

echo ""
echo "=== Full signature of single_prefill_with_kv_cache at line 914 ==="
sed -n '914,960p' flashinfer/prefill.py

echo ""
echo "=== Full signature of single_prefill_with_kv_cache at line 939 ==="
sed -n '939,985p' flashinfer/prefill.py

echo ""
echo "=== Full signature of single_prefill_with_kv_cache at line 963 ==="
sed -n '963,1010p' flashinfer/prefill.py

echo ""
echo "=== Checking how module is called in single prefill (around line 1164) ==="
sed -n '1160,1210p' flashinfer/prefill.py

Length of output: 671


🏁 Script executed:

#!/bin/bash
# Get full function signatures for single_prefill_with_kv_cache functions

echo "=== Full signature of single_prefill_with_kv_cache at line 879 ==="
sed -n '879,920p' flashinfer/prefill.py

echo ""
echo "=== Full signature of single_prefill_with_kv_cache at line 914 ==="
sed -n '914,960p' flashinfer/prefill.py

echo ""
echo "=== Full signature of single_prefill_with_kv_cache at line 939 ==="
sed -n '939,985p' flashinfer/prefill.py

echo ""
echo "=== Full signature of single_prefill_with_kv_cache at line 963 ==="
sed -n '963,1010p' flashinfer/prefill.py

echo ""
echo "=== Checking how module is called in single prefill (around line 1164) ==="
sed -n '1160,1210p' flashinfer/prefill.py

echo ""
echo "=== Checking run_single_prefill function signature and how it calls module ==="
sed -n '268,310p' flashinfer/prefill.py

Length of output: 9560


🏁 Script executed:

#!/bin/bash
# Check the complete module.run() call in single_prefill_with_kv_cache implementation

echo "=== Complete module.run() call (lines 1176-1195) ==="
sed -n '1176,1195p' flashinfer/prefill.py

echo ""
echo "=== Verify run_single_prefill accepts maybe_s_aux (around line 278) ==="
sed -n '278,290p' flashinfer/prefill.py

echo ""
echo "=== Check if run_single_prefill actually uses maybe_s_aux in fa2 backend ==="
sed -n '310,350p' flashinfer/prefill.py

Length of output: 2182


Add maybe_s_aux parameter to single prefill Python wrapper and pass it to module.run().

The single_prefill_with_kv_cache function does not accept or pass the maybe_s_aux parameter that the underlying generated C++ module expects. The custom operation run_single_prefill (line 278) accepts maybe_s_aux and passes it to the fa2 backend (line 329), but the Python wrapper never provides it. This creates a parameter mismatch.

Update single_prefill_with_kv_cache:

  1. Add maybe_s_aux: Optional[torch.Tensor] = None parameter (matching batch prefill pattern)
  2. Pass it to module.run() after maybe_alibi_slopes and before logits_soft_cap (position 11)
🤖 Prompt for AI Agents
In flashinfer/jit/attention/modules.py around lines 519 to 524, the Python
wrapper single_prefill_with_kv_cache does not accept or forward the maybe_s_aux
tensor that the generated C++ module expects; add a parameter maybe_s_aux:
Optional[torch.Tensor] = None to the function signature (matching the batch
prefill pattern) and include that variable in the module.run(...) call argument
list immediately after maybe_alibi_slopes and before logits_soft_cap (i.e., as
the 11th positional argument) so the wrapper matches the underlying
run_single_prefill/fa2 backend usage.
