[BugFix] Use BufferRegion in tl.cumsum to infer buffer shape #1321
Note
Other AI code review bot(s) detected
CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.
Walkthrough
Normalize Buffers/BufferRegions and add region-based access_ptr construction; store src/dst BufferRegion on CumSum/Reduce ops; switch CumSum lowering to derive tvm_access_ptr from regions (1D/2D handling); add cumsum+view test and an ASTPrinter export with minor engine debug logging.
Sequence Diagram(s)
sequenceDiagram
participant Py as TileLang Python
participant Cpp as C++ Lowering (src/op/reduce.cc)
participant IR as TVM IR
Note over Py,Cpp: User invokes cumsum(view or buffer)
Py->>Cpp: pass src/dst as BufferRegion + dim
Cpp->>Cpp: Normalize input (ConvertBufferToBufferRegion / NormalizeToBufferRegion)
Cpp->>Cpp: store srcRegion_/dstRegion_ on OpNode
Cpp->>Cpp: MakeAccessPtrFromRegion(srcRegion, rw_mask)
Cpp->>Cpp: MakeAccessPtrFromRegion(dstRegion, rw_mask)
Cpp->>IR: Emit loads/stores using computed access_ptr (1D / 2D branches)
IR-->>Py: Lowered PrimFunc returned
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pull request overview
This PR fixes issue #1001 where tl.cumsum incorrectly infers buffer shapes when using vmap to find buffers. The fix replaces the use of access_ptr with buffer_to_tile_region to properly handle BufferRegion-based shape inference.
Key Changes
- Modified cumsum and cumsum_fragment in Python to use buffer_to_tile_region instead of access_ptr
- Updated C++ CumSumOpNode to store and use BufferRegion objects (srcRegion_, dstRegion_) for accurate shape tracking
- Added test case to verify the fix for view-based cumsum operations
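The failure mode described in the overview can be sketched in plain Python. The `Buffer` and `BufferRegion` classes, the `view` helper, and the `vmap` dictionary below are hypothetical stand-ins written for this illustration, not TileLang or TVM types. They only model the idea that a lookup keyed by the underlying data variable returns the originally declared buffer, while a region carries the viewed shape along with it.

```python
from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass(frozen=True)
class Buffer:
    """Hypothetical stand-in: a data variable name plus a logical shape."""
    data: str
    shape: Tuple[int, ...]


@dataclass(frozen=True)
class BufferRegion:
    """Hypothetical stand-in: a buffer plus one (min, extent) pair per dim."""
    buffer: Buffer
    ranges: Tuple[Tuple[int, int], ...]


def view(buf: Buffer, shape: Tuple[int, ...]) -> Buffer:
    """Reinterpret the same storage (same data variable) under a new shape."""
    return Buffer(buf.data, shape)


hidden = 128
smem = Buffer("smem", (hidden,))
viewed = view(smem, (1, hidden))

# A map keyed by the data variable only knows the declared 1D buffer.
vmap: Dict[str, Buffer] = {smem.data: smem}
shape_via_vmap = vmap[viewed.data].shape

# A region built from the view keeps the 2D shape the caller actually used.
region = BufferRegion(viewed, ((0, 1), (0, hidden)))
shape_via_region = region.buffer.shape

print(shape_via_vmap, shape_via_region)
```

In this model the variable-keyed lookup silently drops the `(1, hidden)` shape, which is the kind of mismatch the region-based approach is meant to avoid.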
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 8 comments.
Show a summary per file
| File | Description |
|---|---|
| tilelang/language/reduce.py | Updated cumsum operations to use buffer_to_tile_region for proper region tracking |
| tilelang/engine/phase.py | Added debug print statements (should be removed before merge) |
| tilelang/analysis/ast_printer.py | New debugging utility for printing AST structure |
| tilelang/analysis/__init__.py | Exported ASTPrinter for module-level access |
| src/op/reduce.h | Added BufferRegion fields to CumSumOpNode for region-aware shape inference |
| src/op/reduce.cc | Implemented BufferRegion normalization and updated CumSumOp constructor/Lower methods |
| testing/python/issue/test_tilelang_issue_1001.py | Added regression test for issue #1001 |
Comments suppressed due to low confidence (3)
tilelang/engine/phase.py:230
- Debug code should be removed before merging. This debug print statement appears to have been left in during development and should be removed from production code.
return mod
tilelang/language/reduce.py:249
- Call to function buffer_to_tile_region with too few arguments; should be no fewer than 2.
buffer_to_tile_region(cumsum_smem),
tilelang/language/reduce.py:250
- Call to function buffer_to_tile_region with too few arguments; should be no fewer than 2.
buffer_to_tile_region(cumsum_smem),
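A short sketch of why the one-argument calls flagged above would fail at runtime. `buffer_to_tile_region_sketch` is a hypothetical stand-in that only mirrors the two required parameters mentioned in the review; it is not the real `buffer_to_tile_region`, and its return value is made up for the illustration.

```python
def buffer_to_tile_region_sketch(buffer, access_type: str):
    """Hypothetical stand-in with two required positional parameters."""
    return (buffer, access_type)


def call_is_well_formed(*args) -> bool:
    """Return True when the call binds all required parameters."""
    try:
        buffer_to_tile_region_sketch(*args)
    except TypeError:
        # Python raises TypeError when a required argument is missing.
        return False
    return True


missing_access_type = call_is_well_formed("cumsum_smem")
with_access_type = call_is_well_formed("cumsum_smem", "r")
print(missing_access_type, with_access_type)
```

The error only surfaces when the call is executed, so a code path that is not covered by tests can carry it unnoticed.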
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tilelang/language/reduce.py (1)
242-253: buffer_to_tile_region calls are missing required access_type argument in cumsum_fragment
The buffer_to_tile_region function is defined with a required access_type: str parameter (utils.py:29), but both calls at lines 248–249 in cumsum_fragment omit it. This will raise a TypeError at runtime.
The suggested fix is correct: use "r" for the read operation and "w" for the write:
- tir.call_intrin(
-     "handle",
-     tir.op.Op.get("tl.cumsum"),
-     buffer_to_tile_region(cumsum_smem),
-     buffer_to_tile_region(cumsum_smem),
-     dim,
-     reverse,
- )
+ tir.call_intrin(
+     "handle",
+     tir.op.Op.get("tl.cumsum"),
+     buffer_to_tile_region(cumsum_smem, "r"),
+     buffer_to_tile_region(cumsum_smem, "w"),
+     dim,
+     reverse,
+ )
Existing tests (test_cumsum_fragment and test_cumsum_fragment_1d) should validate this path after the fix.
tilelang/engine/phase.py (1)
70-82: Remove unconditional debug output before merging
Both tilelang.analysis.ASTPrinter()(mod) at line 78 and print(mod) at line 230 execute unconditionally on every compilation, dumping IR to stdout and causing noisy, repeated output in normal use.
The codebase already has established debug flag patterns via PassConfigKey and pass_ctx.config (see TIR_ENABLE_DEBUG, TL_DEBUG_MERGE_SHARED_MEMORY_ALLOCATIONS in pass_config.py). Gate both statements behind a debug flag:
- Line 78 (PreLowerSemanticCheck): Wrap ASTPrinter call behind a debug config check (may require adding pass_ctx parameter)
- Line 230 (OptimizeForTarget): Wrap print(mod) behind pass_ctx.config.get(..., False) check (consistent with existing pattern on lines 55, 67, etc.)
Alternatively, remove entirely before merge if these were temporary debugging aids.
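The gating pattern suggested above might look roughly like the following. This is a minimal sketch under stated assumptions: the key name `sketch.debug_dump_ir` and the helper `maybe_dump` are hypothetical, and a plain dictionary stands in for the pass context config. None of these names come from the TileLang codebase.

```python
import io
from contextlib import redirect_stdout
from typing import Any, Mapping

# Hypothetical key name, used only for this sketch.
DEBUG_DUMP_KEY = "sketch.debug_dump_ir"


def maybe_dump(mod: Any, config: Mapping[str, Any]) -> Any:
    """Print `mod` only when the debug flag is set; always return it unchanged."""
    if config.get(DEBUG_DUMP_KEY, False):
        print(mod)
    return mod


def captured_output(config: Mapping[str, Any]) -> str:
    """Run maybe_dump with the given config and return what it printed."""
    buf = io.StringIO()
    with redirect_stdout(buf):
        maybe_dump("<module>", config)
    return buf.getvalue()


quiet = captured_output({})
verbose = captured_output({DEBUG_DUMP_KEY: True})
```

Defaulting the lookup to `False` keeps normal compilations silent while leaving the dump available when someone opts in.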
🧹 Nitpick comments (4)
testing/python/issue/test_tilelang_issue_1001.py (1)
1-33: Good focused regression test for the cumsum+view layout inference bug
The test captures the exact failing pattern (shared 1D buffer, T.view(..., (1, hidden)), T.cumsum(..., dim=1) with the relevant pass configs) and will fail if layout inference regresses again.
If you want extra coverage, you could optionally also compare x.cumsum(dim=1) against the kernel output, but for issue #1001 reproducing successful compilation/execution is already valuable.
src/op/reduce.cc (1)
61-70: Minor duplication in Buffer→Region conversion for tvm_access_ptr case
In NormalizeToBufferRegion, the builtin::tvm_access_ptr branch:
Var var = Downcast<Var>(call->args[1]);
Buffer buf = vmap[var];
Array<Range> ranges;
for (PrimExpr extent : buf->shape) {
  ranges.push_back(Range(IntImm(extent->dtype, 0), extent));
}
return BufferRegion(buf, ranges);
duplicates the logic in ConvertBufferToBufferRegion, which is currently unused:
static BufferRegion ConvertBufferToBufferRegion(const Buffer &buf) {
  Array<Range> ranges;
  for (PrimExpr extent : buf->shape) {
    ranges.push_back(Range(IntImm(extent->dtype, 0), extent));
  }
  return BufferRegion(buf, ranges);
}
You could either:
- Replace the inlined construction with a call to ConvertBufferToBufferRegion(buf), or
- Drop the helper entirely if you don’t plan to use it elsewhere.
This is purely a cleanliness/readability issue; behavior is already correct.
Also applies to: 569-577
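The first option (route every input form through one shared helper) can be modeled in a few lines of Python. The classes and the `normalize` and `full_region` functions below are hypothetical illustrations of the idea, not a port of the C++ code, and they only handle two input forms.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class Buffer:
    """Hypothetical stand-in for a buffer with a static shape."""
    name: str
    shape: Tuple[int, ...]


@dataclass(frozen=True)
class BufferRegion:
    """Hypothetical stand-in: one (min, extent) pair per dimension."""
    buffer: Buffer
    ranges: Tuple[Tuple[int, int], ...]


def full_region(buf: Buffer) -> BufferRegion:
    """Single place that builds the [0, extent) range for every dimension."""
    return BufferRegion(buf, tuple((0, extent) for extent in buf.shape))


def normalize(arg: Union[Buffer, BufferRegion]) -> BufferRegion:
    """Accept either form and always hand back a region."""
    if isinstance(arg, BufferRegion):
        return arg  # already a region, pass through untouched
    if isinstance(arg, Buffer):
        return full_region(arg)  # reuse the shared helper, no inlined copy
    raise TypeError(f"cannot normalize {type(arg).__name__}")


buf = Buffer("A", (4, 8))
from_buffer = normalize(buf)
sub = BufferRegion(buf, ((1, 2), (0, 8)))
from_region = normalize(sub)
```

Keeping the range construction in one function means a later change to how full regions are built only has to be made once.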
tilelang/analysis/ast_printer.py (2)
2-3: Clean up the import formatting.
The PrimFunc import has unconventional formatting with a trailing comma and closing parenthesis on separate lines. Consider simplifying to a single line for consistency with other imports:
-from tvm.tir import (
-    PrimFunc,)
+from tvm.tir import PrimFunc
20-22: Consider prefixing unused parameters with underscore.
The mod and ctx parameters are required by the prim_func_pass signature but are not used in pass_fn. To indicate they are intentionally unused, consider prefixing them with underscores:
-    def pass_fn(func: PrimFunc, mod, ctx) -> PrimFunc:
+    def pass_fn(func: PrimFunc, _mod, _ctx) -> PrimFunc:
         new_body = ir_transform(func.body, pre_visit, None)
         return func.with_body(new_body)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
- src/op/reduce.cc (6 hunks)
- src/op/reduce.h (2 hunks)
- testing/python/issue/test_tilelang_issue_1001.py (1 hunks)
- tilelang/analysis/__init__.py (1 hunks)
- tilelang/analysis/ast_printer.py (1 hunks)
- tilelang/engine/phase.py (2 hunks)
- tilelang/language/reduce.py (2 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.
Applied to files:
testing/python/issue/test_tilelang_issue_1001.py
🧬 Code graph analysis (5)
tilelang/language/reduce.py (1)
tilelang/language/utils.py (1)
buffer_to_tile_region(29-41)
tilelang/analysis/ast_printer.py (1)
tilelang/language/v2/builder.py (1)
PrimFunc(517-526)
testing/python/issue/test_tilelang_issue_1001.py (1)
tilelang/transform/pass_config.py (1)
PassConfigKey(6-144)
src/op/reduce.cc (1)
src/op/gemm.cc (5)
strides(113-113), NormalizeToBufferRegion(53-100), NormalizeToBufferRegion(53-54), MakeAccessPtrFromRegion(106-136), MakeAccessPtrFromRegion(106-107)
tilelang/analysis/__init__.py (1)
tilelang/analysis/ast_printer.py (1)
ASTPrinter(8-24)
🪛 Ruff (0.14.5)
tilelang/analysis/ast_printer.py
20-20: Unused function argument: mod
(ARG001)
20-20: Unused function argument: ctx
(ARG001)
tilelang/analysis/__init__.py
3-3: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Agent
- GitHub Check: CodeQL analysis (python)
- GitHub Check: Quick Lint
🔇 Additional comments (4)
tilelang/language/reduce.py (1)
289-307: Region-based tl.cumsum call from Python looks consistent with new lowering
Switching cumsum to emit regions instead of raw access_ptr:
buffer_to_tile_region(src, "r"),
buffer_to_tile_region(dst, "w"),
aligns with the new NormalizeToBufferRegion/srcRegion_/dstRegion_ flow and should correctly preserve view shapes (including T.view(smem, (1, hidden))).
src/op/reduce.h (1)
84-104: Region fields and reflection for ReduceOpNode/CumSumOpNode look correct
Adding srcRegion_/dstRegion_ and exposing them as srcRegion/dstRegion via reflection is consistent with the new region-based lowering in reduce.cc. Clone behavior is preserved via copy-construction.
Also applies to: 135-151
src/op/reduce.cc (1)
579-599: CumSumOp constructor: region plumbing and dim bound check look good
Switching the constructor to:
- Normalize args[0]/args[1] to srcRegion_/dstRegion_ via NormalizeToBufferRegion, and
- Derive src/dst buffers from those regions,
is exactly what’s needed to preserve view shapes (e.g., T.view(smem, (1, hidden))) instead of relying on the vmap-only buffer lookup. The added upper-bound check on dim vs. src->shape.size() also gives an earlier, clearer error.
Given the Python side normalizes negative dims before constructing CumSumOp, this check is sufficient.
tilelang/analysis/__init__.py (1)
3-3: LGTM! Static analysis hint is a false positive.
The noqa: F401 directive is intentional for this re-export pattern, making ASTPrinter available as part of the public API. The Ruff hint can be safely ignored.
Actionable comments posted: 0
♻️ Duplicate comments (1)
src/op/reduce.cc (1)
594-595: Remove commented-out code.
These commented lines represent the old implementation and should be removed to reduce code clutter.
Apply this diff:
- // node->src = vmap[GetVarFromAccessPtr(args[0])];
- // node->dst = vmap[GetVarFromAccessPtr(args[1])];
🧹 Nitpick comments (2)
testing/python/language/test_tilelang_language_atomic_add.py (1)
263-263: Inconsistent decorator usage.
This is the only decorator in the file using @tilelang.jit() with parentheses, while all other instances (lines 5, 40, 77, 112, 147, 173, 209, 298, 331) use @tilelang.jit without parentheses.
Apply this diff to maintain consistency:
-@tilelang.jit()
+@tilelang.jit
src/op/reduce.cc (1)
94-111: Clarify comment for 2D-only branch.
The comment "Compute row-major strides for ndim >= 2" suggests support for 3+ dimensions, but the ICHECK at line 85 restricts this function to exactly 1D or 2D buffers. Since CumSum only supports 1D and 2D (as stated at line 639), this restriction is appropriate.
Apply this diff to clarify the comment:
- // Compute row-major strides for ndim >= 2
+ // Compute row-major strides for 2D case
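For readers unfamiliar with the stride computation the comment refers to, here is a small sketch of row-major strides and the linear element offset of a region's starting point, restricted to 1D and 2D as in the review. The function names and the exact offset formula are assumptions made for illustration; the real helper in reduce.cc may compute additional quantities.

```python
from typing import Sequence, Tuple


def row_major_strides(shape: Sequence[int]) -> Tuple[int, ...]:
    """Stride of each dimension, in elements, for a dense row-major layout."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return tuple(strides)


def region_offset(shape: Sequence[int], mins: Sequence[int]) -> int:
    """Linear element offset of a region whose per-dimension start is `mins`."""
    if len(shape) not in (1, 2):
        raise ValueError("this sketch only covers 1D and 2D buffers")
    if len(mins) != len(shape):
        raise ValueError("mins must have one entry per dimension")
    return sum(m * s for m, s in zip(mins, row_major_strides(shape)))


strides_2d = row_major_strides((4, 8))      # one row spans 8 elements
offset_2d = region_offset((4, 8), (2, 3))   # 2 * 8 + 3 * 1
offset_1d = region_offset((16,), (5,))      # 1D: the offset is just the start
```

With only two dimensions the loop runs at most once, which is why a comment mentioning "ndim >= 2" overstates what the branch handles.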
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- src/op/reduce.cc (6 hunks)
- testing/python/language/test_tilelang_language_atomic_add.py (1 hunks)
- tilelang/analysis/ast_printer.py (1 hunks)
- tilelang/language/reduce.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- tilelang/language/reduce.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.
Applied to files:
testing/python/language/test_tilelang_language_atomic_add.py
🧬 Code graph analysis (2)
src/op/reduce.cc (1)
src/op/gemm.cc (5)
strides(113-113), NormalizeToBufferRegion(53-100), NormalizeToBufferRegion(53-54), MakeAccessPtrFromRegion(106-136), MakeAccessPtrFromRegion(106-107)
tilelang/analysis/ast_printer.py (1)
tilelang/language/v2/builder.py (1)
PrimFunc(517-526)
🪛 Ruff (0.14.5)
tilelang/analysis/ast_printer.py
19-19: Unused function argument: mod
(ARG001)
19-19: Unused function argument: ctx
(ARG001)
🔇 Additional comments (5)
tilelang/analysis/ast_printer.py (1)
19-21: LGTM!
The unused mod and ctx parameters flagged by static analysis are part of the required callback signature for prim_func_pass and cannot be removed. The implementation correctly applies the pre-order visitor to print statement types for debugging.
src/op/reduce.cc (4)
61-70: LGTM!
The new handling for builtin.tvm_access_ptr correctly extracts the buffer variable and constructs a full BufferRegion using the buffer's shape. This enables proper region-based normalization for tvm_access_ptr calls.
576-584: LGTM!
The helper correctly converts a Buffer to a BufferRegion by creating full ranges [0, extent) for each dimension. This normalization is essential for uniform region-based handling.
596-606: LGTM!
The region-based normalization correctly handles various input forms (BufferRegion, BufferLoad, tvm_access_ptr) and the new dimension validation (lines 602-605) prevents invalid cumsum operations with out-of-range dimensions.
623-625: Core fix for cumsum + view shape inference.
This change correctly builds access pointers from BufferRegion instead of raw arguments, ensuring proper shape inference when T.cumsum is applied to viewed buffers. The region-based approach accurately captures the buffer's dimensions and offsets, addressing issue #1001.
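As a reference for what a cumsum over a viewed buffer is expected to produce, the sketch below computes a cumulative sum along dim=1 of a (1, hidden) view using only the standard library. The `cumsum_rows` helper is hypothetical, and treating `reverse=True` as a suffix sum is an assumption about the operator's semantics rather than something stated in this thread.

```python
from itertools import accumulate
from typing import List


def cumsum_rows(rows: List[List[int]], reverse: bool = False) -> List[List[int]]:
    """Cumulative sum along the last axis (dim=1) of a 2D nested list."""
    out = []
    for row in rows:
        if reverse:
            # Assumed semantics: accumulate from the right (suffix sums).
            out.append(list(accumulate(row[::-1]))[::-1])
        else:
            out.append(list(accumulate(row)))
    return out


flat = [1, 2, 3, 4]   # a 1D buffer of length hidden = 4
viewed = [flat]       # the same data seen as shape (1, hidden)

forward = cumsum_rows(viewed)
backward = cumsum_rows(viewed, reverse=True)
```

If the lowering used the 1D shape of the underlying buffer instead of the (1, hidden) view, dim=1 would not exist, which matches the upper-bound check on dim mentioned in the review.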
Fix #1001. The bug occurs because CumSum uses vmap to find the buffer for the corresponding var, which leads to wrong shape inference.
Summary by CodeRabbit
New Features
Bug Fixes
Tests
Chores