[Example] Specify a fixed commit for the flash-linear-attention repository and optimize nsa examples #913
Conversation
…ample
- Updated requirements.txt to specify a fixed commit for the flash-linear-attention repository.
- Refactored import paths in benchmark_nsa_fwd.py for better organization.
- Added a new function to generate configurations for autotuning.
- Modified the tilelang_sparse_attention function to accept parameters for block size, number of stages, and threads, enhancing flexibility.
- Changed the accumulator allocation to shared memory to optimize performance.
Walkthrough
Adds a new ignore entry for `.claude` in `.gitignore`, introduces autotuning support and a new `get_configs()` in the NSA forward benchmark with an updated kernel signature and import path, and pins flash-linear-attention to a specific commit in requirements.
Sequence Diagram(s)

```mermaid
sequenceDiagram
autonumber
actor User
participant Benchmark as benchmark_nsa_fwd.py
participant TileLang as TileLang Autotune
participant Kernel as SparseAttention Kernel
participant FLA as flash-linear-attention
User->>Benchmark: Run benchmark_nsa(...)
Benchmark->>Benchmark: get_configs()
Benchmark->>TileLang: decorate tilelang_sparse_attention(configs)
TileLang->>Kernel: compile variants (block_T, num_stages, threads)
Note over TileLang,Kernel: Autotune selects best-performing variant
Benchmark->>Kernel: invoke compiled kernel
Kernel->>FLA: use utils (prepare_token_indices)
Kernel-->>Benchmark: results
Benchmark-->>User: benchmark output
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the project's format check before submitting. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Actionable comments posted: 0
🧹 Nitpick comments (3)
examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py (3)
442-452: Consider the autotune search space size.
The configuration generates 75 combinations (3 × 5 × 5). This could result in lengthy autotune times, especially with the default warmup (25) and rep (100) settings from the autotune decorator.
Consider one of the following:
- Reducing the search space for faster iteration during development
- Documenting expected autotune duration
- Adding a smaller "quick" configuration set for testing
Example of a reduced search space:
```diff
 def get_configs():
     import itertools
     iter_params = dict(
-        block_T=[128, 256, 512],
-        num_stages=[0, 1, 2, 4, 5],
-        threads=[32, 64, 128, 256, 512],
+        block_T=[128, 256],
+        num_stages=[1, 2, 4],
+        threads=[128, 256],
     )
     return [{
         k: v for k, v in zip(iter_params, values)
     } for values in itertools.product(*iter_params.values())]
```
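As a rough sanity check on those numbers, the size of either search space is just the product of the option counts; the following standalone sketch (not part of the benchmark) compares the two:

```python
import itertools

def count_configs(iter_params: dict) -> int:
    # Number of autotune candidates = product of the option counts.
    return len(list(itertools.product(*iter_params.values())))

full = dict(block_T=[128, 256, 512], num_stages=[0, 1, 2, 4, 5], threads=[32, 64, 128, 256, 512])
reduced = dict(block_T=[128, 256], num_stages=[1, 2, 4], threads=[128, 256])
print(count_configs(full), count_configs(reduced))  # 75 vs. 12 candidate kernels
```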
454-454: Remove trailing comma.
Minor style nitpick: the trailing comma after `get_configs()` is unnecessary.

```diff
-@tilelang.autotune(configs=get_configs(),)
+@tilelang.autotune(configs=get_configs())
```
515-515: Consider whether Q/K/V shared buffers need layout annotations.
Only `O_shared` has a swizzled layout annotation. Depending on the access patterns, `Q_shared`, `K_shared`, and `V_shared` might also benefit from layout annotations for better memory performance.
If the Q/K/V buffers also have strided access patterns that could benefit from swizzling, consider adding layout annotations:

```diff
-T.annotate_layout({O_shared: tilelang.layout.make_swizzled_layout(O_shared)})
+T.annotate_layout({
+    Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
+    K_shared: tilelang.layout.make_swizzled_layout(K_shared),
+    V_shared: tilelang.layout.make_swizzled_layout(V_shared),
+    O_shared: tilelang.layout.make_swizzled_layout(O_shared)
+})
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- .gitignore (1 hunk)
- examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py (6 hunks)
- examples/deepseek_nsa/requirements.txt (1 hunk)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py (4)
- tilelang/autotuner/tuner.py (1)
  - autotune (701-794)
- tilelang/jit/__init__.py (1)
  - jit (237-310)
- tilelang/math/__init__.py (1)
  - next_power_of_2 (1-2)
- tilelang/language/__init__.py (1)
  - annotate_layout (106-144)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: build-test-amd
- GitHub Check: format-check
🔇 Additional comments (7)
.gitignore (1)
96-97: LGTM!
The `.claude` ignore pattern is properly formatted and consistent with other directory-level ignores in the file.

examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py (5)
464-467: LGTM! Autotune parameters properly integrated.
The new tuning parameters (`block_T`, `num_stages`, `threads`) are well integrated with sensible defaults, allowing the function to work both with and without autotuning.
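As a generic illustration of that pattern (not the benchmark's actual code; the names and default values below are placeholders), exposing the tuning knobs as keyword arguments with working defaults lets the same builder be called directly or swept by an autotuner:

```python
# Hypothetical sketch of the "tunable kwargs with defaults" pattern; the real
# tilelang_sparse_attention signature and defaults may differ.
def build_sparse_attention(seq_len, dim, block_T=128, num_stages=2, threads=128):
    # The real benchmark would compile a TileLang kernel specialized for these
    # parameters; here we simply echo the chosen configuration.
    return {"block_T": block_T, "num_stages": num_stages, "threads": threads}

# Direct call: works without any autotuner thanks to the defaults.
print(build_sparse_attention(seq_len=4096, dim=128))

# Under autotuning, the decorator would instead invoke the builder once per
# entry of get_configs() and keep the fastest compiled variant.
```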
481-481: LGTM! Proper validation of block_T.
Clamping `block_T` to the next power of 2 of `dim` prevents invalid configurations and ensures memory alignment.
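For illustration only, a pure-Python stand-in for `tilelang.math.next_power_of_2` shows one plausible form of that clamp (the exact expression in the benchmark may differ):

```python
def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n (stand-in for tilelang.math.next_power_of_2).
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

dim = 96       # hypothetical feature dimension
block_T = 512  # candidate tile size from an autotune config

# Cap the tile at the next power of two covering dim so every autotune
# configuration stays valid and power-of-two aligned.
block_T = min(block_T, next_power_of_2(dim))
print(block_T)  # 128
```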
507-507: Verify the memory allocation change is intentional.
The allocation of `acc_s_cast` has changed from fragment (register) memory to shared memory. This could impact performance, as shared memory has higher latency than registers.
Please confirm:
- Is this change necessary for correctness with the new autotune configurations?
- Have you measured the performance impact of this change?
If this was done to support larger block sizes or fix a compilation issue, please document the reasoning.
611-621: LGTM! Clearer variable naming.
Renaming from `program` to `kernel` better reflects that this is a compiled, executable kernel rather than a program representation.
13-13: Cannot locate prepare_token_indices; please verify import.
The search didn’t find any definition or re-export of `prepare_token_indices` under `fla/ops/utils.py`. Manually confirm that `fla.ops.utils.prepare_token_indices` is present in the pinned flash-linear-attention (commit c3bd565).
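One quick way to perform that manual check, assuming the pinned flash-linear-attention commit is installed in the current environment, is to attempt the import directly:

```python
# Raises ImportError if the symbol is missing or was moved in the pinned revision.
from fla.ops.utils import prepare_token_indices  # noqa: F401
print("fla.ops.utils.prepare_token_indices is importable")
```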
examples/deepseek_nsa/requirements.txt (1)
1-1: Pinned commit verified. The specified commit exists in fla-org/flash-linear-attention; no further action required.
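For reference, pinning a Git dependency in requirements.txt typically uses pip's direct-reference syntax; the line in this PR may be formatted differently, and the SHA below is a placeholder rather than the actual pinned commit:

```text
flash-linear-attention @ git+https://github.com/fla-org/flash-linear-attention.git@<full-commit-sha>
```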