[Enhancement] Fix lint to improve grouped GEMM performance with TMA #938
Conversation
Walkthrough

Updates the grouped GEMM forward example to compute total_m_blocks from batch_sizes_list, use it as the kernel grid x-dimension, and remove a +1 from the per-batch padding calculation; other functional flow and APIs are unchanged.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant E as Example Script
    participant C as Config Builder
    participant K as TileLang Kernel
    E->>C: prepare batch_sizes_list, block_M, block_N, N
    C->>C: total_m_blocks = Σ ((size + block_M - 1) // block_M)
    Note right of C: grid.x = total_m_blocks<br/>grid.y = ceil(N / block_N)
    C->>C: padding_i = ceil(batch_sizes_list[i] / padding_M) * padding_M
    C-->>K: launch(grid=(total_m_blocks, ceil(N/block_N)), ...)
    K->>K: execute grouped GEMM blocks
    K-->>E: return/completion
```
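To make the walkthrough's formulas concrete, here is a minimal, self-contained sketch in plain Python; all values below are illustrative, not the example's real configuration:

```python
# Illustrative values only; the real example reads these from its config.
batch_sizes_list = [96, 64, 200]
block_M, block_N, N = 64, 128, 1024
padding_M = block_M

# grid.x: exact number of M blocks across all batches (per-batch ceil-div).
total_m_blocks = sum((size + block_M - 1) // block_M for size in batch_sizes_list)

# grid.y: ceil-div over the N dimension.
grid_y = (N + block_N - 1) // block_N

# Per-batch padding, with no "+1" inside the ceiling (the fix in this PR).
paddings = [((size + padding_M - 1) // padding_M) * padding_M
            for size in batch_sizes_list]

print(total_m_blocks, grid_y, paddings)  # 7 8 [128, 64, 256]
```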
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes
Actionable comments posted: 0
🧹 Nitpick comments (2)

examples/grouped_gemm/example_grouped_gemm_fwd.py (2)

7-7: Remove commented-out code. This commented line should either be removed entirely or uncommented if the cache needs to be disabled for development/debugging purposes.

Apply this diff if you choose to remove it:

```diff
-# tilelang.disable_cache()
```
60-60: LGTM! Precomputing total_m_blocks improves precision and performance. The calculation correctly computes the exact number of M blocks needed across all batches using `(size + block_M - 1) // block_M`, which is more efficient than the previous approach that would overallocate. This directly contributes to the performance improvement shown in the PR (53.6 → 85.7 TFlops).

Optionally, for consistency with the codebase, consider extracting the ceildiv formula into a helper function:

```python
def ceildiv(a, b):
    return (a + b - 1) // b

total_m_blocks = sum(ceildiv(size, block_M) for size in batch_sizes_list)
```

Based on learnings from examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py where a similar helper exists.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
examples/grouped_gemm/example_grouped_gemm_fwd.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/grouped_gemm/example_grouped_gemm_fwd.py (2)
examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py (1)
  ceildiv (87-88)
tilelang/language/kernel.py (1)
  threads (195-199)
🔇 Additional comments (2)
examples/grouped_gemm/example_grouped_gemm_fwd.py (2)
72-72: LGTM! Kernel grid sizing correctly uses precomputed total_m_blocks. The kernel launch now uses the exact number of M blocks needed (`total_m_blocks`) instead of overallocating. This is consistent with the precomputation on line 60 and the corrected padding calculation on line 117. The kernel logic properly handles batch boundaries and partial blocks.
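For orientation, here is a hedged, schematic sketch of such a launch in TileLang; this is not the example's actual kernel. The body is a trivial tile copy standing in for the grouped GEMM, it assumes block_N divides N, and the API names (`T.Kernel`, `T.ceildiv`, `T.Tensor`, `T.Parallel`) follow public TileLang examples but may differ across versions:

```python
import tilelang.language as T


def build_launch_sketch(total_m_blocks: int, N: int,
                        block_M: int = 64, block_N: int = 128):
    # Sketch only: grid.x is the precomputed exact M-block count,
    # grid.y tiles N in block_N-wide chunks.
    @T.prim_func
    def kernel(X: T.Tensor((total_m_blocks * block_M, N), "float16"),
               Y: T.Tensor((total_m_blocks * block_M, N), "float16")):
        with T.Kernel(total_m_blocks, T.ceildiv(N, block_N),
                      threads=128) as (bx, by):
            # Stand-in body: copy one (block_M, block_N) tile per block;
            # the real example performs the grouped GEMM here instead.
            for i, j in T.Parallel(block_M, block_N):
                Y[bx * block_M + i, by * block_N + j] = \
                    X[bx * block_M + i, by * block_N + j]

    return kernel
```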
117-117: LGTM! Removing +1 fixes unnecessary padding at exact block boundaries. The previous formula `math.ceil((batch_sizes_list[i] + 1) / padding_M) * padding_M` caused overallocation when batch sizes were exactly divisible by `padding_M`. For example, with `size=64` and `padding_M=64`, the old formula would pad to 128 instead of 64. Removing the `+1` corrects this behavior and contributes to the performance improvement.
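The boundary case is easy to verify in plain Python (the sizes here are illustrative):

```python
import math

padding_M = 64
for size in (63, 64, 65):
    old = math.ceil((size + 1) / padding_M) * padding_M  # previous formula, with +1
    new = math.ceil(size / padding_M) * padding_M        # corrected formula
    print(size, old, new)
# 63 -> old 64,  new 64   (unchanged)
# 64 -> old 128, new 64   (the overallocation removed by the fix)
# 65 -> old 128, new 128  (unchanged)
```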
We're good to go once the useless comments are removed.
…ile-ai#938)

* [Example] Fix lint to improve grouped GEMM performance with TMA
* fix lint
for default tests:

before: 53.6 TFlops
after: 85.7 TFlops