
Conversation

@Cunxiao2002 (Contributor) commented Oct 5, 2025

For the default tests:
Before:

  • Latency: 0.48080000281333923 ms
  • TFlops: 53.59776128371738 TFlops

After:

  • Latency: 0.30057600140571594 ms
  • TFlops: 85.73473482740243 TFlops

Summary by CodeRabbit

  • Refactor
    • Kernel/grid sizing now uses a precomputed total M-block count for more consistent launch dimensions.
    • Batching and padding logic streamlined while preserving overall behavior.
  • Bug Fixes
    • Fixed padding calculation to prevent unintended extra padding in batched GEMM.
  • Notes
    • No API changes; behavior remains compatible.

github-actions bot commented Oct 5, 2025

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀

@Cunxiao2002 Cunxiao2002 changed the title [Example] Fix lint to improve grouped GEMM performance with TMA [Enhancement] Fix lint to improve grouped GEMM performance with TMA Oct 5, 2025
coderabbitai bot commented Oct 5, 2025

Walkthrough

Updates the grouped GEMM forward example to compute total_m_blocks from batch_sizes_list, use it as the kernel grid x-dimension, and remove a +1 from the per-batch padding calculation; other functional flow and APIs are unchanged.

Changes

Cohort / File(s): Grouped GEMM example adjustments (examples/grouped_gemm/example_grouped_gemm_fwd.py)
Summary:
- Remove the commented-out top-level cache-disabling call (# tilelang.disable_cache())
- Compute total_m_blocks = sum((size + block_M - 1) // block_M for size in batch_sizes_list) and use it for the kernel grid x-dimension
- Replace previous x-dimension T.ceildiv(batch_sum, block_M) + batch_count with total_m_blocks; y-dimension remains T.ceildiv(N, block_N)
- Change padding calculation to math.ceil(batch_sizes_list[i] / padding_M) * padding_M (removed +1)
- No exported/public API changes; internal launch sizing and padding logic adjusted
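Taken together, a minimal before/after sketch of the sizing and padding changes (standalone Python with hypothetical batch shapes; the example's real values come from its configuration):

import math

batch_sizes_list = [128, 200, 64]  # hypothetical, for illustration only
block_M = 64
padding_M = 64

# Before: grid.x over-allocated by batch_count extra blocks.
batch_sum = sum(batch_sizes_list)                                           # 392
old_grid_x = (batch_sum + block_M - 1) // block_M + len(batch_sizes_list)   # 7 + 3 = 10

# After: exact per-batch block counts, summed.
total_m_blocks = sum((size + block_M - 1) // block_M for size in batch_sizes_list)  # 2 + 4 + 1 = 7

# Before: the +1 pushed exact multiples of padding_M up one extra tile.
old_pad = [math.ceil((s + 1) / padding_M) * padding_M for s in batch_sizes_list]    # [192, 256, 128]
# After: round up to the nearest multiple only when needed.
new_pad = [math.ceil(s / padding_M) * padding_M for s in batch_sizes_list]          # [128, 256, 64]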

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant E as Example Script
  participant C as Config Builder
  participant K as Triton Kernel

  E->>C: prepare batch_sizes_list, block_M, block_N, N
  C->>C: total_m_blocks = Σ ((size + block_M -1) // block_M)
  Note right of C: grid.x = total_m_blocks<br/>grid.y = ceil(N / block_N)
  C->>C: padding_i = ceil(batch_sizes_list[i] / padding_M) * padding_M
  C-->>K: launch(grid=(total_m_blocks, ceil(N/block_N)), ...)
  K->>K: execute grouped GEMM blocks
  K-->>E: return/completion

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Poem

I nibble on blocks and count each row,
summing M-tiles where the data go.
No +1 crumbs, padding neat and trim —
kernels leap, my whiskers brim.
Hooray for tiled hops and a launch so prim! 🥕

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
  • Title Check: ⚠️ Warning. The title emphasizes lint fixes, whereas the changeset actually updates kernel block calculations, removes a cache-disabling call, and adjusts padding logic to improve grouped GEMM performance with TMA, so it does not accurately describe the main modifications. Resolution: revise the title to summarize the core changes, for example: "Adjust grouped GEMM kernel dimensions and padding to improve TMA performance."
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 50.00%, below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve coverage.
✅ Passed checks (1 passed)
  • Description Check: ✅ Passed. Check skipped; CodeRabbit's high-level summary is enabled.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5a733ec and 51e74db.

📒 Files selected for processing (1)
  • examples/grouped_gemm/example_grouped_gemm_fwd.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/grouped_gemm/example_grouped_gemm_fwd.py (1)
examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py (1)
  • ceildiv (87-88)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: build-test-amd
🔇 Additional comments (3)
examples/grouped_gemm/example_grouped_gemm_fwd.py (3)

58-58: LGTM! Correct computation of total M-dimension blocks.

The ceiling division formula (size + block_M - 1) // block_M correctly computes the number of blocks needed for each batch, and summing across all batches gives the exact total. This replaces the previous over-allocation that added batch_count extra blocks, which is a key optimization contributing to the performance improvement.
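As a sanity check, the integer formula and the true ceiling agree for positive operands; a standalone Python check (not part of the example file):

import math

# (size + block_M - 1) // block_M == ceil(size / block_M) for positive integers.
for size in range(1, 1000):
    for block_M in (16, 32, 64, 128):
        assert (size + block_M - 1) // block_M == math.ceil(size / block_M)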


70-70: LGTM! Kernel grid dimension correctly uses total_m_blocks.

Using total_m_blocks as the first grid dimension ensures the kernel launches exactly the right number of blocks needed for the work, eliminating the over-allocation from the previous approach. This aligns with the block assignment logic in lines 77-89.
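For intuition, a flat grid of this kind implies a mapping from the global m-block index (grid.x) back to a (batch, local m-block) pair via the per-batch block counts. A hypothetical host-side sketch in Python (the example performs the equivalent lookup inside the TileLang kernel):

def locate_block(bx, batch_sizes_list, block_M):
    # Walk the batches, subtracting each batch's block count until bx lands inside one.
    for batch, size in enumerate(batch_sizes_list):
        n_blocks = (size + block_M - 1) // block_M
        if bx < n_blocks:
            return batch, bx  # local m-block index within this batch
        bx -= n_blocks
    raise IndexError("bx exceeds total_m_blocks")

# With batches [128, 200, 64] and block_M=64 (7 blocks total):
assert locate_block(5, [128, 200, 64], 64) == (1, 3)
assert locate_block(6, [128, 200, 64], 64) == (2, 0)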


115-115: LGTM! Fixes padding calculation bug.

Removing the +1 corrects the padding calculation. The previous formula would incorrectly allocate an extra padding_M rows whenever batch_sizes_list[i] was a multiple of padding_M (e.g., 64 would pad to 128 instead of 64). This fix reduces unnecessary padding, contributing to the performance improvement.



coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (2)
examples/grouped_gemm/example_grouped_gemm_fwd.py (2)

7-7: Remove commented-out code.

This commented line should either be removed entirely or uncommented if the cache needs to be disabled for development/debugging purposes.

Apply this diff if you choose to remove it:

-# tilelang.disable_cache()
-

60-60: LGTM! Precomputing total_m_blocks improves precision and performance.

The calculation correctly computes the exact number of M blocks needed across all batches using (size + block_M - 1) // block_M, which is more efficient than the previous approach that would overallocate. This directly contributes to the performance improvement shown in the PR (53.6 → 85.7 TFlops).

Optionally, for consistency with the codebase, consider extracting the ceildiv formula into a helper function:

def ceildiv(a, b):
    return (a + b - 1) // b

total_m_blocks = sum(ceildiv(size, block_M) for size in batch_sizes_list)

A similar helper already exists in examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b31de0c and 5a733ec.

📒 Files selected for processing (1)
  • examples/grouped_gemm/example_grouped_gemm_fwd.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/grouped_gemm/example_grouped_gemm_fwd.py (2)
examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py (1)
  • ceildiv (87-88)
tilelang/language/kernel.py (1)
  • threads (195-199)
🔇 Additional comments (2)
examples/grouped_gemm/example_grouped_gemm_fwd.py (2)

72-72: LGTM! Kernel grid sizing correctly uses precomputed total_m_blocks.

The kernel launch now uses the exact number of M blocks needed (total_m_blocks) instead of overallocating. This is consistent with the precomputation on line 60 and the corrected padding calculation on line 117. The kernel logic properly handles batch boundaries and partial blocks.


117-117: LGTM! Removing +1 fixes unnecessary padding at exact block boundaries.

The previous formula math.ceil((batch_sizes_list[i] + 1) / padding_M) * padding_M caused overallocation when batch sizes were exactly divisible by padding_M. For example, with size=64 and padding_M=64, the old formula would pad to 128 instead of 64. Removing the +1 corrects this behavior and contributes to the performance improvement.

@LeiWang1999
Copy link
Member

We're good to go once the useless comments are removed.
