feat: add sink to flashinfer decode #2087
Conversation
Co-authored-by: Nitin Gupta <ngupta@groq.com>
Co-authored-by: Duncan Moss <djmmoss@gmail.com>
Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.
Walkthrough: Adds optional per-head sink auxiliary (
Changes
Sequence Diagram
sequenceDiagram
autonumber
actor User
participant PyAPI as Python API (decode/prefill)
participant Utils as _get_sink_buf
participant JIT as JIT Module
participant Kernel as CUDA Kernel
User->>PyAPI: run_* (q, ..., maybe_s_aux / sinks)
PyAPI->>Utils: _get_sink_buf(maybe_s_aux)
Utils-->>PyAPI: sink_buf or None
alt JIT path
PyAPI->>JIT: invoke module with maybe_s_aux tensor
JIT->>Kernel: kernel invoked (has maybe_s_aux)
else non-JIT path
PyAPI->>Kernel: call kernel with sink_ptr (from _get_sink_buf)
end
Kernel->>Kernel: if use_softmax && maybe_s_aux:\n compute exp2((s_aux - max)*LOG2_E) and add to denom
Kernel-->>PyAPI: attention outputs
PyAPI-->>User: return result
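In plain terms, the sink is a learnable per-head logit that joins the softmax normalization but contributes no value vector. A rough reference sketch of that math (NumPy, illustrative names and shapes only; whether s_aux is additionally scaled by sm_scale is the implementation question discussed in the kernel comments below):

import numpy as np

def decode_attention_with_sink(q, k, v, s_aux, sm_scale):
    # q: [num_heads, head_dim], k/v: [kv_len, num_heads, head_dim], s_aux: [num_heads]
    out = np.empty_like(q)
    for h in range(q.shape[0]):
        logits = k[:, h, :] @ q[h] * sm_scale        # scaled attention scores, [kv_len]
        m = max(logits.max(), s_aux[h])              # include the sink in the max for stability
        p = np.exp(logits - m)                       # softmax numerators
        denom = p.sum() + np.exp(s_aux[h] - m)       # the sink only enlarges the denominator
        out[h] = (p / denom) @ v[:, h, :]            # no value vector is attached to the sink
    return out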
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Summary of Changes
Hello @djmmoss, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request integrates 'sink attention' into FlashInfer's decode functionality, a mechanism vital for managing context length in large language models by allowing a learnable sink token to influence the attention softmax. The changes span the Python API, JIT compilation modules, and underlying CUDA kernels, ensuring a cohesive implementation. A new utility function facilitates proper handling of sink tensors, and a comprehensive suite of unit tests has been added to validate the feature's correctness and robustness across various configurations, including grouped query attention.
Highlights
Code Review
This pull request introduces support for sink attention in the flashinfer decode functionality. The changes are well-structured, touching upon Python bindings, JIT compilation, and the core C++ CUDA kernels. A comprehensive suite of tests has been added to validate the new feature. The implementation appears correct. I have one suggestion regarding the C++ kernel code to improve maintainability by reducing code duplication.
// Add s_aux (learnable sink) contribution to softmax denominator after all tiles processed
if constexpr (variant.use_softmax) {
  if (params.maybe_s_aux != nullptr) {
    constexpr float LOG2_E = 1.4426950408889634f;  // log2(e)
    float s_aux_val = params.maybe_s_aux[qo_head_idx];
    st_local.d += math::ptx_exp2((s_aux_val - st_local.m) * LOG2_E);
  }
}
This logic for adding the sink contribution is duplicated in BatchDecodeWithPagedKVCacheDevice (lines 600-607). To improve maintainability and reduce code duplication, consider extracting this block into a helper function.
Also, the constant LOG2_E is defined inline here and in the other location. It would be better to define it once at the top of the file in an anonymous namespace to avoid magic numbers and ensure consistency.
For example, you could add at the top of the file:
namespace flashinfer {
namespace { // anonymous namespace
static constexpr float LOG2_E = 1.4426950408889634f; // log2(e)
template <typename AttentionVariant, typename State, typename Params>
__device__ __forceinline__ void AddSinkContribution(AttentionVariant variant, State& st,
const Params& params,
uint32_t qo_head_idx) {
if constexpr (variant.use_softmax) {
if (params.maybe_s_aux != nullptr) {
float s_aux_val = params.maybe_s_aux[qo_head_idx];
st.d += math::ptx_exp2((s_aux_val - st.m) * LOG2_E);
}
}
}
} // anonymous namespace
// ... rest of the file
Then you could replace this block and the one in BatchDecodeWithPagedKVCacheDevice with a call to this helper function:
AddSinkContribution(variant, st_local, params, qo_head_idx);
Note: Adding attention sink support to the CUDA cores template (in
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
flashinfer/decode.py (2)
1260-1285: Validate and prepare sinks once (device + shape). Before building run args, validate that the sinks length matches num_qo_heads and move it to q.device. Prepare a buffer to reuse in both branches.
 if rope_theta is None:
     rope_theta = 1e4
+# prepare sinks (optional)
+sink_buf = None
+if sinks is not None:
+    if sinks.dim() != 1 or sinks.numel() != q.shape[1]:
+        raise ValueError(
+            f"sinks must be 1D with length num_qo_heads={q.shape[1]}, got {tuple(sinks.shape)}"
+        )
+    sink_buf = _get_sink_buf(sinks.to(q.device))
1291-1338: Remove sinks argument from the tensor-core FA2 paged_run path only; the non-tensor-core fix in the review is incomplete. The tensor-core issue is confirmed critical: FA2's paged_run kernel does not pass the sinks argument to the actual kernel, yet line 1337 passes _get_sink_buf(sinks). This will fail at runtime due to an ABI mismatch.
However, the review's non-tensor-core fix is incomplete. The non-tensor-core decode module's run() method does accept sinks, but the review references a sink_buf variable that doesn't exist in the codebase. Line 1371 currently calls _get_sink_buf(sinks) inline. To apply the review's intent, either:
- Pre-compute sink_buf = _get_sink_buf(sinks) before the if/else block, then reference it, or
- Keep the inline _get_sink_buf(sinks) call in the non-tensor-core path as-is
Remove _get_sink_buf(sinks) from line 1337 in the tensor-core path to fix the critical issue.
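To make the pre-compute option concrete, a minimal sketch (the helper and module handles are passed in as parameters because the real attribute names live in decode.py; this is not the actual code):

def build_run_args(base_args, sinks, jit_module, get_sink_buf):
    # Compute the sink buffer exactly once so both branches agree on it.
    sink_buf = get_sink_buf(sinks)      # returns None when sinks is None
    run_args = list(base_args)
    if jit_module is None:
        # non-tensor-core / non-JIT kernel: append the prepared buffer
        run_args.append(sink_buf)
    # JIT path: leave run_args unchanged until the JIT module itself accepts sinks
    return run_args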
🧹 Nitpick comments (1)
flashinfer/utils.py (1)
240-255: Clarify expected device/shape for sinks (keep the helper simple). The current helper does dtype/contiguity handling only. Upstream validation (device move to q.device and shape == num_qo_heads) should happen where q is available (decode.py). Consider adding a short note in the docstring that sinks must be on the same CUDA device and have length == num_qo_heads.
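A docstring along the lines the reviewer suggests could read as follows (hypothetical sketch; the real _get_sink_buf in flashinfer/utils.py may differ in details):

from typing import Optional
import torch

def _get_sink_buf(sinks: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """Return the per-head sink logits as a contiguous float32 tensor.

    Note: ``sinks`` is expected to be on the same CUDA device as ``q`` and to
    have length ``num_qo_heads``; that validation happens in decode.py, where
    ``q`` is available.
    """
    if sinks is None:
        return None
    return sinks.to(torch.float32).contiguous()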
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- flashinfer/decode.py (6 hunks)
- flashinfer/jit/attention/modules.py (2 hunks)
- flashinfer/utils.py (1 hunks)
- include/flashinfer/attention/decode.cuh (2 hunks)
- include/flashinfer/attention/default_decode_params.cuh (6 hunks)
- tests/attention/test_decode_sink_attention.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
flashinfer/decode.py (1)
flashinfer/utils.py (1)
_get_sink_buf(240-254)
tests/attention/test_decode_sink_attention.py (3)
tests/test_helpers/sink_attention_reference.py (1)
sink_attention_unified (39-402)
flashinfer/jit/env.py (1)
has_flashinfer_jit_cache (27-36)
flashinfer/decode.py (7)
BatchDecodeWithPagedKVCacheWrapper (586-1422), plan (815-1106), plan (1615-1738), run (1136-1149), run (1152-1165), run (1167-1386), run (1740-1864)
🪛 Ruff (0.14.4)
tests/attention/test_decode_sink_attention.py
1-1: The file is executable but no shebang is present
(EXE002)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (2)
flashinfer/jit/attention/modules.py (1)
763-765: Batch-decode JIT addition looks correct. Adding maybe_s_aux alongside maybe_alibi_slopes for batch decode matches the new params and kernel usage.
Please confirm generated JIT bindings reflect the new pointer in the correct order (tensor args before scalar args).
include/flashinfer/attention/default_decode_params.cuh (1)
40-63: Param extension LGTM. maybe_s_aux is added and zero-initialized across constructors; this matches the kernel guards.
| ["maybe_alibi_slopes", "maybe_s_aux"], # additional_tensor_names | ||
| ["float", "float"], # additional_tensor_dtypes | ||
| [ |
Avoid breaking single-decode JIT: remove maybe_s_aux here (or plumb it through Python).
The single-decode Python wrapper doesn’t pass maybe_s_aux. Adding it here shifts the C++ run signature and will misalign subsequent args (logits_soft_cap becomes the second tensor param), breaking calls.
Minimal fix:
- ["maybe_alibi_slopes", "maybe_s_aux"], # additional_tensor_names
- ["float", "float"], # additional_tensor_dtypes
+ ["maybe_alibi_slopes"], # additional_tensor_names
+ ["float"], # additional_tensor_dtypesIf you want single-decode sink support, also update flashinfer/decode.py:get_single_decode_module wrappers to accept and pass maybe_s_aux in the correct position.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| ["maybe_alibi_slopes", "maybe_s_aux"], # additional_tensor_names | |
| ["float", "float"], # additional_tensor_dtypes | |
| [ | |
| ["maybe_alibi_slopes"], # additional_tensor_names | |
| ["float"], # additional_tensor_dtypes | |
| [ |
🤖 Prompt for AI Agents
In flashinfer/jit/attention/modules.py around lines 470-472, adding
"maybe_s_aux" to the additional tensor names shifts the C++ JIT run signature
and breaks the single-decode wrapper which does not pass that arg; remove
"maybe_s_aux" from the additional_tensor_names and its dtype from
additional_tensor_dtypes so the signature remains unchanged, or alternatively
update flashinfer/decode.py:get_single_decode_module to accept a maybe_s_aux
parameter and forward it in the exact positional order expected by the JIT run
(ensure both the names and dtypes lists and all wrapper call sites stay
consistent).
// Add s_aux (learnable sink) contribution to softmax denominator after all tiles processed
if constexpr (variant.use_softmax) {
  if (params.maybe_s_aux != nullptr) {
    constexpr float LOG2_E = 1.4426950408889634f;  // log2(e)
    float s_aux_val = params.maybe_s_aux[qo_head_idx];
    st_local.d += math::ptx_exp2((s_aux_val - st_local.m) * LOG2_E);
  }
}
Fix s_aux scaling to match logits path.
s[j] is scaled by variant.sm_scale_log2 before softmax. s_aux currently uses (s_aux - m) * LOG2_E, which mismatches and yields incorrect normalization. Scale s_aux with variant.sm_scale_log2 and drop LOG2_E.
Apply:
- if constexpr (variant.use_softmax) {
- if (params.maybe_s_aux != nullptr) {
- constexpr float LOG2_E = 1.4426950408889634f; // log2(e)
- float s_aux_val = params.maybe_s_aux[qo_head_idx];
- st_local.d += math::ptx_exp2((s_aux_val - st_local.m) * LOG2_E);
- }
- }
+ if constexpr (variant.use_softmax) {
+ if (params.maybe_s_aux != nullptr) {
+ float s_aux_scaled = params.maybe_s_aux[qo_head_idx] * variant.sm_scale_log2;
+ st_local.d += math::ptx_exp2(s_aux_scaled - st_local.m);
+ }
+ }
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
-// Add s_aux (learnable sink) contribution to softmax denominator after all tiles processed
-if constexpr (variant.use_softmax) {
-  if (params.maybe_s_aux != nullptr) {
-    constexpr float LOG2_E = 1.4426950408889634f;  // log2(e)
-    float s_aux_val = params.maybe_s_aux[qo_head_idx];
-    st_local.d += math::ptx_exp2((s_aux_val - st_local.m) * LOG2_E);
-  }
-}
+// Add s_aux (learnable sink) contribution to softmax denominator after all tiles processed
+if constexpr (variant.use_softmax) {
+  if (params.maybe_s_aux != nullptr) {
+    float s_aux_scaled = params.maybe_s_aux[qo_head_idx] * variant.sm_scale_log2;
+    st_local.d += math::ptx_exp2(s_aux_scaled - st_local.m);
+  }
+}
🤖 Prompt for AI Agents
In include/flashinfer/attention/decode.cuh around lines 358 to 365, the s_aux
contribution is being added using (s_aux - m) * LOG2_E which mismatches the
logits path scaling; change the computation to multiply (s_aux_val - st_local.m)
by variant.sm_scale_log2 and remove LOG2_E so the call becomes
math::ptx_exp2((s_aux_val - st_local.m) * variant.sm_scale_log2); keep the same
null check and use_softmax guard.
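For intuition about the units involved, a small self-contained check (the values are arbitrary; the bookkeeping follows the review's description that logits are pre-multiplied by sm_scale_log2 and the running max m lives in that scaled log2 domain):

import math

sm_scale = 0.125
log2e = math.log2(math.e)
s_aux, qk_max = 0.7, 9.0
m = qk_max * sm_scale * log2e                              # running max in the scaled log2 domain

sink_as_raw_logit = 2 ** (s_aux * log2e - m)               # e^(s_aux - sm_scale*qk_max)
sink_scaled_like_qk = 2 ** (s_aux * sm_scale * log2e - m)  # the suggestion above
current_kernel_term = 2 ** ((s_aux - m) * log2e)           # mixes natural-log and log2 domains
print(sink_as_raw_logit, sink_scaled_like_qk, current_kernel_term)

Which of the first two forms is correct depends on whether the reference implementation treats the sink as a raw logit or as a pre-scale score; the review's point is that the third expression matches neither.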
if constexpr (variant.use_softmax) {
  if (params.maybe_s_aux != nullptr) {
    constexpr float LOG2_E = 1.4426950408889634f;  // log2(e)
    float s_aux_val = params.maybe_s_aux[qo_head_idx];
    st.d += math::ptx_exp2((s_aux_val - st.m) * LOG2_E);
  }
}
Same scaling fix for batch kernel.
Mirror the s_aux scaling correction here to keep behavior consistent across kernels.
- if constexpr (variant.use_softmax) {
- if (params.maybe_s_aux != nullptr) {
- constexpr float LOG2_E = 1.4426950408889634f; // log2(e)
- float s_aux_val = params.maybe_s_aux[qo_head_idx];
- st.d += math::ptx_exp2((s_aux_val - st.m) * LOG2_E);
- }
- }
+ if constexpr (variant.use_softmax) {
+ if (params.maybe_s_aux != nullptr) {
+ float s_aux_scaled = params.maybe_s_aux[qo_head_idx] * variant.sm_scale_log2;
+ st.d += math::ptx_exp2(s_aux_scaled - st.m);
+ }
+ }
Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In include/flashinfer/attention/decode.cuh around lines 601-607, the batch-kernel branch needs the same s_aux scaling fix as the non-batch path: read s_aux_val = params.maybe_s_aux[qo_head_idx], scale it by variant.sm_scale_log2, subtract st.m, pass that to math::ptx_exp2, and add the result to st.d, dropping the inline LOG2_E constant, so the auxiliary scaling matches the other kernel.
Add maybe_s_aux support to prefill template used for decode attention when use_tensor_cores=True. Includes updates to params structures, variant handling, JIT generation, Python wrappers, and comprehensive test coverage with validation.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
include/flashinfer/attention/default_prefill_params.cuh (1)
1-1: Fix the clang-format errors before merging. The pipeline indicates formatting violations. Please run clang-format on this file to resolve the pre-commit hook failure.
Run the following command to fix formatting:
#!/bin/bash
# Description: Format the file according to clang-format rules
clang-format -i include/flashinfer/attention/default_prefill_params.cuh
flashinfer/decode.py (3)
1170-1185: Sinks are silently ignored in JIT batch decode paths; confirmation required on whether this is intentional. The review comment is correct. Code inspection confirms:
Lines 1314-1343 (tensor-core path): When self._jit_module is not None, only list(args) extends run_args (line 1315); _get_sink_buf(sinks) is never added. In the non-JIT branch, _get_sink_buf(sinks) is correctly appended (line 1340).
Lines 1369-1381 (CUDA-core path): Identical pattern; the JIT branch (line 1370) extends with list(args) only, while the non-JIT branch adds _get_sink_buf(sinks) (line 1374).
Since sinks is a named parameter (not part of *args), it is not captured by list(args) and is completely dropped for JIT paths. This creates a silent API inconsistency where the parameter is accepted but only respected in non-JIT code, risking hard-to-trace behavioral divergence. The suggested fix (fail fast with ValueError if JIT is enabled and sinks are provided, or thread sinks into the JIT module calls) is appropriate.
1873-1932: Update sinks parameter type annotations from List[torch.Tensor] to torch.Tensor in three decode functions. The review correctly identifies a type annotation mismatch. Three public functions advertise sinks: Optional[List[torch.Tensor]] but the code internally expects a single torch.Tensor:
- trtllm_batch_decode_with_kv_cache (line 2083)
- trtllm_batch_decode_with_kv_cache_mla (line 2547)
- xqa_batch_decode_with_kv_cache_mla (line 2715)
This causes a runtime error when callers pass a list, as the implementation calls sinks.reshape(num_kv_heads, -1) at line 2459, which requires a tensor. Change all three function signatures from sinks: Optional[List[torch.Tensor]] = None to sinks: Optional[torch.Tensor] = None, and update the docstring at line 2137 to match.
354-392: Expose sinks in all overloads and document its semantics for single_decode_with_kv_cache. The first overload (with return_lse: Literal[False]) does not list sinks, while the second overload and the implementation do. Type checkers will reject valid calls like single_decode_with_kv_cache(..., sinks=...) when return_lse=False. Additionally, the docstring does not describe sinks at all, nor document that it is currently only honored in the use_tensor_cores=True path and silently ignored otherwise.
Required changes:
- Add sinks to the first overload (lines 354-371): return_lse: Literal[False] = False, sinks: Optional[torch.Tensor] = None, ) -> torch.Tensor: ...
- Document sinks in the docstring (add to the Parameters section after return_lse): sinks : Optional[torch.Tensor], optional per-head sink values of shape [num_qo_heads] added to the softmax denominator. Currently only supported when use_tensor_cores=True; for the non-tensor-core decode kernel this argument is ignored.
- Verify with mypy or pyright that both overloads now accept sinks calls. Consider raising a ValueError when sinks is not None and not use_tensor_cores to make the limitation explicit rather than silent.
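A sketch of the first-overload change described above (only q/k/v and the two relevant keyword parameters are shown; the real parameter list is longer, and the keyword-only marker is a simplification):

from typing import Literal, Optional, Tuple, overload
import torch

@overload
def single_decode_with_kv_cache(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *,
    return_lse: Literal[False] = False,
    sinks: Optional[torch.Tensor] = None,
) -> torch.Tensor: ...

@overload
def single_decode_with_kv_cache(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *,
    return_lse: Literal[True],
    sinks: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ...

def single_decode_with_kv_cache(q, k, v, *, return_lse=False, sinks=None):
    ...  # implementation elided; only the overload signatures matter here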
♻️ Duplicate comments (1)
tests/attention/test_decode_sink_attention.py (1)
258-322: GQA sink-attention path is now explicitly covered
test_batch_decode_sink_attention_gqa sets num_qo_heads=16 and num_kv_heads=8, so the grouped-query broadcast path is actually exercised, addressing earlier feedback. The sanity checks on shape, dtype, and NaN/Inf are appropriate for this focused test. No action needed; this resolves the earlier concern about GQA coverage in this area.
🧹 Nitpick comments (3)
tests/attention/test_decode_sink_attention.py (3)
48-55: JIT warmup fixture is harmless but effectively a no-op
The warmup_jit fixture is autouse only when no flashinfer_jit_cache is present, but its body just yields. That's fine (and cheap), though if the intention was to proactively trigger decode JIT compilation, you may later want to add a minimal decode call here; otherwise this is acceptable as is. If you later add true JIT warmup logic, please ensure it's guarded by CUDA availability checks to avoid spurious failures on non-GPU test environments.
58-189: Strong coverage for batch decode + sinks across MHA/GQA and paging configs
test_batch_decode_with_sink_attention exercises a wide grid of (batch_size, kv_len, num_qo_heads, num_kv_heads, head_dim, page_size) with a GQA case and validates against the reference implementation using tolerances that are reasonable for bf16 and different reduction orders. The paged-KV reconstruction logic from kv_data_fp32 into [B, kv_len, num_kv_heads, D] looks correct for both full and partial pages. In a future pass, you might consider adding a non-trivial window_left value and an HND layout here to exercise those branches more directly, but it isn't strictly necessary for this PR.
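For readers unfamiliar with the paged layout, the reconstruction amounts to gathering this request's pages in order and trimming the partially filled last page. A minimal sketch with assumed NHD page shapes (variable names are illustrative, not the test's):

import torch

def gather_kv_from_pages(paged_k: torch.Tensor, kv_indices: torch.Tensor, last_page_len: int) -> torch.Tensor:
    # paged_k: [num_pages, page_size, num_kv_heads, head_dim]  (NHD page layout)
    page_size = paged_k.shape[1]
    kv_len = (kv_indices.numel() - 1) * page_size + last_page_len
    pages = paged_k[kv_indices]                   # gather the pages used by this request, in order
    flat = pages.reshape(-1, *pages.shape[2:])    # [n_pages * page_size, num_kv_heads, head_dim]
    return flat[:kv_len]                          # drop the unused tail of the last page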
407-408: Address Ruff EXE002 by dropping the __main__ block or adding a shebang
The if __name__ == "__main__": pytest.main(...) block combined with the executable bit triggers Ruff's EXE002 ("file is executable but no shebang"). Since this is a pytest module, the simplest fix is to remove the __main__ block entirely and rely on pytest discovery:
-if __name__ == "__main__":
-    pytest.main([__file__, "-v"])
Alternatively, if you really want direct execution, add an appropriate shebang (#!/usr/bin/env python3) at the top and keep the block. After adjusting, re-run pre-commit locally to ensure Ruff no longer reports EXE002 for this file.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- flashinfer/decode.py (9 hunks)
- flashinfer/jit/attention/modules.py (2 hunks)
- flashinfer/prefill.py (3 hunks)
- include/flashinfer/attention/default_prefill_params.cuh (12 hunks)
- include/flashinfer/attention/variants.cuh (1 hunks)
- tests/attention/test_decode_sink_attention.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
flashinfer/decode.py (1)
flashinfer/utils.py (1)
_get_sink_buf(240-254)
tests/attention/test_decode_sink_attention.py (2)
tests/test_helpers/sink_attention_reference.py (1)
sink_attention_unified (39-402)
flashinfer/decode.py (12)
BatchDecodeWithPagedKVCacheWrapper (589-1425), plan (818-1109), plan (1618-1741), run (1139-1152), run (1155-1168), run (1170-1389), run (1743-1867), single_decode_with_kv_cache (355-371), single_decode_with_kv_cache (375-392), single_decode_with_kv_cache (395-586), use_tensor_cores (787-788), use_tensor_cores (1591-1592)
🪛 GitHub Actions: pre-commit
flashinfer/jit/attention/modules.py
[error] 1-1: clang-format formatting failed. Hooks modified the file.
tests/attention/test_decode_sink_attention.py
[error] 1-1: ruff-format formatting failed. Hooks reformatted 2 files.
include/flashinfer/attention/default_prefill_params.cuh
[error] 1-1: clang-format formatting failed. Hooks modified the file.
🪛 Ruff (0.14.4)
tests/attention/test_decode_sink_attention.py
1-1: The file is executable but no shebang is present
(EXE002)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (11)
include/flashinfer/attention/variants.cuh (1)
94-102: LGTM: Sink support implementation is correct. The M-D update hook correctly implements the attention sink mechanism by augmenting the softmax denominator d with an exponential term derived from the auxiliary sink values. The conditional checks ensure the update only applies when softmax is enabled and sink values are provided.
flashinfer/jit/attention/modules.py (1)
519-520: Note: the fa2 backend receives maybe_s_aux; fa3 does not. The maybe_s_aux tensor is added only to the fa2 backend's single prefill module. The fa3 backend path (lines 530-544) does not include this tensor. Verify this is intentional and document it if fa3 sink support is planned for future work.
flashinfer/prefill.py (2)
334-334: LGTM: maybe_s_aux correctly forwarded in the fa2 backend. The parameter is properly threaded through to the kernel invocation for the fa2 backend path, consistent with the JIT module changes.
280-280: Now I understand the issue clearly. Let me check one more thing to confirm whether this is intentional design or an oversight. Based on my comprehensive verification: maybe_s_aux is not exposed in the public API and is never passed from callers; this is intentional design, not a bug. Attention sinks are conditionally selected for both prefill and decode when attention sinks are enabled, but the current implementation shows that support depends on upcoming FlashInfer features. The codebase confirms:
- No public API exposure: single_prefill_with_kv_cache (the public API) has no sink parameter
- Decode-only usage: the sinks tensor is passed to the attention calls only in the decode path, not prefill
- Internal parameter only: maybe_s_aux exists in run_single_prefill (an internal JIT function) but module.run() doesn't pass it (it remains None implicitly)
The behavior is consistent: sink token support is currently a decode-only feature. The parameter exists in the internal function signature but is never populated from the public API layer.
include/flashinfer/attention/default_prefill_params.cuh (3)
41-41: LGTM: maybe_s_aux consistently added to SinglePrefillParams. The new member is properly declared, initialized to nullptr in the default constructor, and correctly initialized in the parameterized constructor following the same pattern as maybe_alibi_slopes.
Also applies to: 70-70, 92-92, 105-105
153-153: LGTM: maybe_s_aux consistently added to BatchPrefillRaggedParams. The changes follow the same pattern as SinglePrefillParams with proper initialization in both constructors.
Also applies to: 198-198, 233-233, 251-251
307-307: LGTM: maybe_s_aux consistently added to BatchPrefillPagedParams. All three parameter structs now consistently support the optional maybe_s_aux auxiliary tensor.
Also applies to: 344-344, 374-374, 387-387
flashinfer/decode.py (1)
44-68: Sink auxiliary buffer (maybe_s_aux) is wired consistently into the batch decode custom op
The added maybe_s_aux parameter is threaded from the Python custom-op signature into run_func in the same relative position as alibi_slopes, which keeps the call layout coherent with the CUDA side. No issues here from the Python side; just ensure the C++ kernel and JIT templates were updated to the same signature ordering. Please re-run the existing batch decode tests (including the new sink tests) once more after any low-level changes to confirm there's no ABI mismatch.
Also applies to: 219-247, 252-273, 276-299
tests/attention/test_decode_sink_attention.py (3)
27-45: Reference helper cleanly reuses the unified sink attention implementation
sink_attention_decode_ref wrapping sink_attention_unified in "incremental" mode gives a clear, single source of truth for expected decode behavior with sinks. This is a solid choice to keep the tests aligned with the reference implementation. Rely on the existing tests/test_helpers/sink_attention_reference.py coverage to validate that the reference stays in sync with any future kernel changes.
191-256: Clear no-sink vs zero-sink equivalence test
test_batch_decode_without_sink_attention nicely checks that sinks=None and an all-zero sink tensor produce numerically equivalent outputs within a tight tolerance, which is exactly the subtle regression this new plumbing could introduce. The setup (single configuration, NHD layout, fixed sizes) is sufficient as a smoke test. Keep this test around as you evolve sink handling; it will quickly catch any accidental change in the "no-sink" fast path semantics.
324-405: Single-token tensor-core decode with sinks is well validated against the reference
test_single_decode_sink_attention_tensor_cores:
- Covers both kv_layout="NHD" and "HND" by reshaping/transposing into the reference's [B, kv_len, num_kv_heads, D] format (see the sketch after this comment).
- Uses bf16 and a realistic sm_scale, with sinks in a logit-scaled range.
- Compares the tensor-core path output to sink_attention_decode_ref with tolerances consistent with the batch test, plus basic shape/dtype/NaN/Inf checks.
This gives good confidence that the single-decode tensor-core sink path matches the intended math.
If you later enable sinks for the non-tensor-core single decode kernel, adding a similar test that toggles use_tensor_cores=False will help guard that path as well.
Actionable comments posted: 1
🧹 Nitpick comments (2)
tests/attention/test_decode_sink_attention.py (2)
294-294: Inconsistent sink value generation. This test uses torch.rand(num_qo_heads) * 5.0 (a uniform distribution over [0, 5.0]) while other tests use torch.randn(num_qo_heads) * 0.5 (a normal distribution scaled by 0.5). For consistency and alignment with the documented scale expectations ("similar scale to logits"), consider using the same generation pattern.
Apply this diff:
-    sinks = torch.rand(num_qo_heads, device=device, dtype=torch.float32) * 5.0
+    sinks = torch.randn(num_qo_heads, device=device, dtype=torch.float32) * 0.5
258-322: Consider consolidating or enhancing this GQA test. This test performs only basic sanity checks (shape, dtype, NaN/Inf) without validating against the reference implementation. Since test_batch_decode_with_sink_attention already includes GQA scenarios (32:8 ratio) with full numerical validation, this test may be redundant. Consider either:
- Removing it in favor of the comprehensive test coverage above, or
- Adding reference validation to ensure GQA-specific logic is thoroughly tested
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- flashinfer/jit/attention/modules.py (3 hunks)
- include/flashinfer/attention/default_prefill_params.cuh (12 hunks)
- tests/attention/test_decode_sink_attention.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/attention/test_decode_sink_attention.py (3)
tests/test_helpers/sink_attention_reference.py (1)
sink_attention_unified (39-402)
flashinfer/jit/env.py (1)
has_flashinfer_jit_cache (27-36)
flashinfer/decode.py (12)
BatchDecodeWithPagedKVCacheWrapper (589-1425), plan (818-1109), plan (1618-1741), run (1139-1152), run (1155-1168), run (1170-1389), run (1743-1867), single_decode_with_kv_cache (355-371), single_decode_with_kv_cache (375-392), single_decode_with_kv_cache (395-586), use_tensor_cores (787-788), use_tensor_cores (1591-1592)
🪛 Ruff (0.14.5)
tests/attention/test_decode_sink_attention.py
1-1: The file is executable but no shebang is present
(EXE002)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (10)
include/flashinfer/attention/default_prefill_params.cuh (1)
41-41: LGTM! Consistent addition of sink auxiliary support across all parameter structures. The maybe_s_aux member is correctly added to all three prefill parameter structures (SinglePrefillParams, BatchPrefillRaggedParams, BatchPrefillPagedParams) with proper:
- Member declaration placement (after maybe_alibi_slopes)
- Nullptr initialization in default constructors
- Integration into parameterized constructors with correct initializer list ordering
The implementation follows existing patterns and maintains consistency across the codebase.
Also applies to: 70-70, 91-91, 104-104, 152-152, 197-197, 232-234, 249-249, 305-305, 342-342, 372-374, 384-384
tests/attention/test_decode_sink_attention.py (7)
1-26: LGTM! The imports and copyright header are clean and all dependencies are properly utilized throughout the test module.
27-45: LGTM! The reference wrapper correctly delegates to sink_attention_unified with appropriate decode-specific parameters (causal=True, mode="incremental").
48-55: Clarify the purpose of this empty fixture. The warmup_jit fixture has autouse=not has_flashinfer_jit_cache() but its body only contains a yield with no warmup logic. The comment "This will be built on-demand during tests" suggests lazy compilation, but the fixture name implies it performs warmup. Either implement actual warmup logic or clarify why an empty fixture is needed here.
58-189: LGTM! Excellent comprehensive test coverage for batch decode with sink attention. The test properly:
- Covers various batch sizes, KV lengths, head configurations (including GQA), dimensions, and page sizes
- Constructs paged KV caches with correct metadata (indptr, indices, last_page_len)
- Converts paged format to reference format for both NHD and HND layouts
- Validates results against the reference implementation with appropriate bf16-aware tolerances
191-255: LGTM! This test appropriately validates that the no-sink path (sinks=None) produces results equivalent to zero sinks with relaxed tolerances that account for code path differences and bf16 precision.
324-408: LGTM! This test provides excellent coverage for the single decode path with tensor cores, including:
- Multiple KV lengths and head configurations (MHA and GQA)
- Both NHD and HND layouts with proper format conversion for reference comparison
- Numerical validation against the reference implementation with appropriate bf16 tolerances
411-412: LGTM! Standard pytest entry point for direct execution.
flashinfer/jit/attention/modules.py (2)
470-471: Python wrappers properly accept and pass the sinks parameter. Verification confirms that Python wrappers have been updated to include the sinks parameter (e.g., lines 391, 411 in single_decode_with_kv_cache, and additional functions at lines 1183, 1896, 1990, 2055, 2083). The parameter is correctly forwarded to the JIT modules as maybe_s_aux at line 538 and throughout batch decode operations. The integration is consistent across all modified functions.
767-768: Based on the verification, the batch decode Python wrappers have been successfully updated to accept and pass the sinks parameter.
Evidence:
- The trtllm_batch_decode_with_kv_cache function signature includes sinks: List[torch.Tensor] | None = None
- Line 1183 in flashinfer/decode.py shows the run() method accepting sinks: Optional[torch.Tensor] = None
- Lines 1340 and 1374 demonstrate the sinks being processed via _get_sink_buf(sinks) in both paged and non-paged execution paths
- Test files confirm successful usage: wrapper.run(q, kv_data, sinks=sinks) is called across multiple test cases with the parameter properly forwarded
The changes to the JIT modules are consistent with the wrapper implementations; all three decode variants (single decode, batch decode with paged KV, and batch decode with padded KV) now uniformly support the maybe_s_aux/sinks parameter.
additional_tensor_names = [
    "maybe_custom_mask",
    "maybe_alibi_slopes",
    "maybe_s_aux",
]
additional_tensor_dtypes = ["uint8_t", "float", "float"]
🧩 Analysis chain
Verify that single prefill Python wrappers accept and pass maybe_s_aux parameter.
The single prefill module (fa2 backend) now includes maybe_s_aux as the third additional tensor parameter. Ensure that Python wrappers for single prefill operations (e.g., flashinfer.single_prefill_with_kv_cache, flashinfer/prefill.py:get_single_prefill_module) have been updated to accept and forward the sinks parameter.
Run the following script to verify the prefill wrapper signatures:
🏁 Script executed:
#!/bin/bash
# Description: Check if Python wrappers for single prefill have been updated to handle sinks parameter
echo "=== Checking single_prefill_with_kv_cache function signature ==="
rg -n -A5 'def single_prefill_with_kv_cache' --type=py
echo ""
echo "=== Checking get_single_prefill_module calls ==="
rg -n -B3 -A10 'get_single_prefill_module' --type=py
echo ""
echo "=== Checking for sinks parameter in prefill.py ==="
rg -n -C5 'sinks' python/flashinfer/prefill.py 2>/dev/null || rg -n -C5 'sinks' flashinfer/prefill.pyLength of output: 10216
🏁 Script executed:
#!/bin/bash
# Get full function signatures for single_prefill_with_kv_cache functions
echo "=== Full signature of single_prefill_with_kv_cache at line 879 ==="
sed -n '879,920p' flashinfer/prefill.py
echo ""
echo "=== Full signature of single_prefill_with_kv_cache at line 914 ==="
sed -n '914,960p' flashinfer/prefill.py
echo ""
echo "=== Full signature of single_prefill_with_kv_cache at line 939 ==="
sed -n '939,985p' flashinfer/prefill.py
echo ""
echo "=== Full signature of single_prefill_with_kv_cache at line 963 ==="
sed -n '963,1010p' flashinfer/prefill.py
echo ""
echo "=== Checking how module is called in single prefill (around line 1164) ==="
sed -n '1160,1210p' flashinfer/prefill.py
Length of output: 671
🏁 Script executed:
#!/bin/bash
# Get full function signatures for single_prefill_with_kv_cache functions
echo "=== Full signature of single_prefill_with_kv_cache at line 879 ==="
sed -n '879,920p' flashinfer/prefill.py
echo ""
echo "=== Full signature of single_prefill_with_kv_cache at line 914 ==="
sed -n '914,960p' flashinfer/prefill.py
echo ""
echo "=== Full signature of single_prefill_with_kv_cache at line 939 ==="
sed -n '939,985p' flashinfer/prefill.py
echo ""
echo "=== Full signature of single_prefill_with_kv_cache at line 963 ==="
sed -n '963,1010p' flashinfer/prefill.py
echo ""
echo "=== Checking how module is called in single prefill (around line 1164) ==="
sed -n '1160,1210p' flashinfer/prefill.py
echo ""
echo "=== Checking run_single_prefill function signature and how it calls module ==="
sed -n '268,310p' flashinfer/prefill.py
Length of output: 9560
🏁 Script executed:
#!/bin/bash
# Check the complete module.run() call in single_prefill_with_kv_cache implementation
echo "=== Complete module.run() call (lines 1176-1195) ==="
sed -n '1176,1195p' flashinfer/prefill.py
echo ""
echo "=== Verify run_single_prefill accepts maybe_s_aux (around line 278) ==="
sed -n '278,290p' flashinfer/prefill.py
echo ""
echo "=== Check if run_single_prefill actually uses maybe_s_aux in fa2 backend ==="
sed -n '310,350p' flashinfer/prefill.py
Length of output: 2182
Add maybe_s_aux parameter to single prefill Python wrapper and pass it to module.run().
The single_prefill_with_kv_cache function does not accept or pass the maybe_s_aux parameter that the underlying generated C++ module expects. The custom operation run_single_prefill (line 278) accepts maybe_s_aux and passes it to the fa2 backend (line 329), but the Python wrapper never provides it. This creates a parameter mismatch.
Update single_prefill_with_kv_cache:
- Add a maybe_s_aux: Optional[torch.Tensor] = None parameter (matching the batch prefill pattern)
- Pass it to module.run() after maybe_alibi_slopes and before logits_soft_cap (position 11)
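A sketch of the wrapper change described above (the function name and reduced parameter list are illustrative; the forwarding position follows the review's description):

from typing import Optional
import torch

def single_prefill_with_kv_cache_sketch(
    module,                                        # generated JIT module exposing .run(...)
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
    maybe_alibi_slopes: Optional[torch.Tensor] = None,
    maybe_s_aux: Optional[torch.Tensor] = None,    # new optional per-head sink logits
    logits_soft_cap: float = 0.0,
):
    # Forward maybe_s_aux immediately after maybe_alibi_slopes and before
    # logits_soft_cap; the other run() arguments are omitted in this sketch.
    return module.run(q, k, v, maybe_alibi_slopes, maybe_s_aux, logits_soft_cap)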
🤖 Prompt for AI Agents
In flashinfer/jit/attention/modules.py around lines 519 to 524, the Python
wrapper single_prefill_with_kv_cache does not accept or forward the maybe_s_aux
tensor that the generated C++ module expects; add a parameter maybe_s_aux:
Optional[torch.Tensor] = None to the function signature (matching the batch
prefill pattern) and include that variable in the module.run(...) call argument
list immediately after maybe_alibi_slopes and before logits_soft_cap (i.e., as
the 11th positional argument) so the wrapper matches the underlying
run_single_prefill/fa2 backend usage.
This PR adds sink support to standard flashinfer decode
Summary by CodeRabbit
New Features
Tests