[TileOp] Implement CumSum1D
#978
Conversation
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the format checks. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Walkthrough

Adds 1D CumSum support end-to-end: a new CUDA template CumSum1D, lowering of CumSumOp that calls the 1D or 2D kernel based on input rank and errors for >2D, updated tests covering the 1D smem/fragment variants, and expanded docstring examples. No other public APIs changed.
Sequence Diagram(s)

```mermaid
sequenceDiagram
autonumber
participant IR as CumSumOpNode::Lower
participant Extern as Extern Call Builder
participant TL1D as tl::CumSum1D
participant TL2D as tl::CumSum2D
participant Err as Fatal Error
IR->>IR: Inspect input rank (ndim)
alt ndim == 1
IR->>Extern: Build args (dst, src, N, threads, reverse)
Extern->>TL1D: run(dst, src, N)
Note over TL1D: Inclusive scan (forward or reverse)
else ndim == 2
IR->>Extern: Build args (dst, src, M, N, threads, dim, reverse)
Extern->>TL2D: run(dst, src, M, N)
Note over TL2D: Inclusive scan along chosen dim
else (>2)
IR-->>Err: TL_FAIL("CumSum supports 1D/2D only")
end
```
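The branching above lives in C++ in CumSumOpNode::Lower; the Python sketch below only models the dispatch and argument layout shown in the diagram (the helper name and the 1D template parameter list are assumptions, inferred by analogy with the 2D form).

```python
def build_cumsum_call(ndim, shape, threads, dim, reverse):
    """Illustrative model of the rank-based dispatch described in the diagram.
    It only mirrors which kernel is named and which arguments are passed."""
    rev = "true" if reverse else "false"
    if ndim == 1:
        # 1D: tl::CumSum1D<threads, reverse>::run(dst, src, N)  (parameters assumed)
        return f"tl::CumSum1D<{threads}, {rev}>::run", ("dst", "src", shape[0])
    if ndim == 2:
        # 2D: tl::CumSum2D<threads, dim, reverse>::run(dst, src, M, N)
        return f"tl::CumSum2D<{threads}, {dim}, {rev}>::run", ("dst", "src", shape[0], shape[1])
    raise ValueError("CumSum supports 1D/2D only")  # mirrors the TL_FAIL branch
```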
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks and finishing touches: ❌ 1 failed check (warning), ✅ 2 checks passed.
Actionable comments posted: 0
🧹 Nitpick comments (2)
src/op/reduce.cc (1)
423-441: The 1D/2D branching logic is correct.

The implementation properly distinguishes between 1D and 2D cases, with appropriate dimension checks and error handling for unsupported dimensionalities. The argument construction and template instantiation for both CumSum1D and CumSum2D are correct.

Consider using string access modes for consistency.

Lines 430 and 435 use integer access modes (1, 3) for access_ptr, while other parts of the codebase (e.g., lines 37-38 in ReduceOpNode) use string modes ("r", "w"). Using strings would improve readability and maintain consistency. Apply this diff for better consistency:

```diff
-    args = {StringImm(ss.str()), src.access_ptr(1), dst.access_ptr(3),
-            src->shape[0]};
+    args = {StringImm(ss.str()), src.access_ptr("r"), dst.access_ptr("w"),
+            src->shape[0]};
   } else if (ndim == 2) {
     ss << "tl::CumSum2D<" << threads << ", " << dim << ", "
        << (reverse ? "true" : "false") << ">::run";
-    args = {StringImm(ss.str()), src.access_ptr(1), dst.access_ptr(3),
-            src->shape[0], src->shape[1]};
+    args = {StringImm(ss.str()), src.access_ptr("r"), dst.access_ptr("w"),
+            src->shape[0], src->shape[1]};
```

testing/python/language/test_tilelang_language_cumsum.py (1)
74-141: Well-structured 1D cumsum tests with correct reference implementation.

The 1D test functions properly mirror the 2D test structure and correctly validate block-local cumulative sum behavior:

- Test kernels (cumsum_smem_test_1d, cumsum_fragment_test_1d): Correctly implement 1D cumsum in both shared memory and fragment scopes.
- Reference implementation (lines 123-136): Accurately models block-local cumsum, processing each block independently and handling reverse mode by flipping before/after the cumsum operation (see the numpy sketch after this list).
- Test coverage: Includes both forward and reverse modes, matching the capabilities of the underlying CUDA kernel.
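As a compact illustration of the block-local semantics described above, a numpy sketch of such a reference could look like this (the block size and function name are illustrative, not the exact code in the test file):

```python
import numpy as np

def blockwise_cumsum_1d(x: np.ndarray, block: int, reverse: bool = False) -> np.ndarray:
    """Block-local 1D cumsum reference: each `block`-sized chunk is scanned
    independently; reverse mode flips the chunk before and after the scan."""
    out = x.copy()
    for start in range(0, x.shape[0], block):
        chunk = x[start:start + block]
        if reverse:
            chunk = chunk[::-1]
        chunk = np.cumsum(chunk)
        if reverse:
            chunk = chunk[::-1]
        out[start:start + chunk.shape[0]] = chunk
    return out
```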
Extract exception message to reduce line length.
Line 118 has a long error message embedded in the raise statement. Consider extracting it to a variable for better readability.
Apply this diff:
```diff
-        raise ValueError(f"Unknown scope {scope}")
+        msg = f"Unknown scope {scope}. Expected 'smem' or 'fragment'."
+        raise ValueError(msg)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- src/op/reduce.cc (2 hunks)
- src/tl_templates/cuda/reduce.h (1 hunk)
- testing/python/language/test_tilelang_language_cumsum.py (2 hunks)
- tilelang/language/reduce.py (1 hunk)
🧰 Additional context used
🧬 Code graph analysis (1)
testing/python/language/test_tilelang_language_cumsum.py (4)
- tilelang/language/reduce.py (1): cumsum (157-207)
- tilelang/language/allocate.py (2): alloc_shared (21-36), alloc_fragment (53-64)
- tilelang/language/copy.py (1): copy (10-86)
- tilelang/jit/__init__.py (1): compile (34-90)
🪛 Ruff (0.13.3)
testing/python/language/test_tilelang_language_cumsum.py
118-118: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: build-test-metal
- GitHub Check: format-check
🔇 Additional comments (3)
tilelang/language/reduce.py (1)
163-185: LGTM! Clear and helpful documentation examples.

The added examples effectively demonstrate both 1D and 2D cumsum usage patterns, showing proper buffer allocation, copying, and the cumsum operation with different configurations (including reverse mode). These examples will help users understand the API.
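For orientation, a 1D usage sketch in the TileLang style might look roughly like the snippet below; the kernel skeleton and the cumsum keyword names (dim, reverse) follow the documented patterns but should be verified against the actual docstring in tilelang/language/reduce.py.

```python
import tilelang
import tilelang.language as T

def cumsum_1d(N=1024, block_N=256, reverse=False, dtype="float32"):
    # Sketch only: shapes, names, and the exact T.cumsum keywords are assumptions.
    @T.prim_func
    def main(A: T.Tensor((N,), dtype), B: T.Tensor((N,), dtype)):
        with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx:
            A_shared = T.alloc_shared((block_N,), dtype)
            T.copy(A[bx * block_N], A_shared)
            T.cumsum(A_shared, dim=0, reverse=reverse)  # block-local inclusive scan
            T.copy(A_shared, B[bx * block_N])

    return main

kernel = tilelang.compile(cumsum_1d())  # compile entry point per the code graph above
```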
src/tl_templates/cuda/reduce.h (1)
71-137: Excellent implementation of 1D cumulative sum with warp-level primitives.

The CumSum1D template correctly implements both forward and reverse cumulative sum using efficient warp shuffle operations (a plain-Python model of the scan follows this list):

- Forward mode: Uses __shfl_up_sync for the inclusive prefix sum within each segment and correctly propagates the carry from left to right across segments.
- Reverse mode: Uses __shfl_down_sync for the reverse cumulative sum within each segment and correctly propagates the carry from right to left across segments.
- Edge cases: Properly handles N ≤ 0 (early return) and partial segments (boundary checks at lines 93, 104, 116, 127).
- Design: Restricting the work to the first SEG threads (lines 84-85) is appropriate for warp-level operations and the sequential carry dependency.
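As a mental model only (not the CUDA code), the result and the segment/carry structure described above can be written sequentially in Python; the real kernel performs the within-segment scan in parallel with warp shuffles.

```python
def cumsum1d_model(x, seg=32, reverse=False):
    """Sequential model of CumSum1D: scan SEG-sized segments and fold each
    finished segment's total into a carry for the segments that follow."""
    n = len(x)
    out = [0] * n
    order = range(n - 1, -1, -1) if reverse else range(n)
    carry, acc = 0, 0
    for count, i in enumerate(order):
        if count % seg == 0:             # entering a new segment:
            carry, acc = carry + acc, 0  # fold the previous segment total into the carry
        acc += x[i]                      # inclusive scan within the segment
        out[i] = carry + acc             # add the carry from earlier segments
    return out

assert cumsum1d_model([1, 2, 3, 4], seg=2) == [1, 3, 6, 10]
assert cumsum1d_model([1, 2, 3, 4], seg=2, reverse=True) == [10, 9, 7, 4]
```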
testing/python/language/test_tilelang_language_cumsum.py (1)
164-171: LGTM! Test entry points provide good coverage.

The test functions exercise both shared memory and fragment scopes with forward and reverse modes, providing comprehensive coverage of the 1D cumsum functionality.
Local tests pass, merged.
- support cumsum-1d
- cumsum 1d support
This pull request adds support for 1D cumulative sum (CumSum) operations in both the CUDA backend and the TileLang language, including corresponding tests for shared memory and fragment scopes. The changes include a new implementation for 1D CumSum on CUDA, updates to the lowering logic to handle 1D buffers, and new Python tests to verify correctness.
1D CumSum Support
- New CumSum1D template in reduce.h for efficient 1D cumulative sum operations, supporting both forward and reverse directions.
- Updated lowering in reduce.cc to handle 1D CumSum operations, including error checking for dimensionality and proper argument passing to the CUDA kernel.

Testing Enhancements

- New test kernels and runner (cumsum_smem_test_1d, cumsum_fragment_test_1d, and run_cumsum_1d) exercising 1D CumSum behavior in both shared memory and fragment scopes. These tests include reference implementations and correctness checks against the CUDA kernel output.
- New entry points (test_cumsum_smem_1d, test_cumsum_fragment_1d) added to the test suite to ensure comprehensive coverage of 1D CumSum operations, including reverse mode.

Summary by CodeRabbit

- New Features: 1D cumulative sum (CumSum) support.
- Tests: New 1D cumsum tests covering shared memory and fragment scopes, forward and reverse modes.
- Documentation: Expanded cumsum docstring examples.