Skip to content

Conversation

@SiriusNEO
Copy link
Collaborator

@SiriusNEO SiriusNEO commented Nov 24, 2025

Fix #1001. This is because CumSum uses vmap to find the buffer of corresponding var, which leads to a wrong shape inference.

Summary by CodeRabbit

  • New Features

    • Added an AST printer debugging tool and a module-level AST dump at the start of semantic checks.
  • Bug Fixes

    • Improved error messaging and dimension validation for cumulative-sum operations.
  • Tests

    • Added an integration test for cumsum view layout inference and adjusted a JIT test decorator configuration.
  • Chores

    • Migrated buffer handling to region-based inputs and aligned pointer sourcing for reduce/cumsum paths.

✏️ Tip: You can customize this high-level summary in your review settings.

@github-actions
Copy link

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Nov 24, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Normalize Buffers/BufferRegions and add region-based access_ptr construction; store src/dst BufferRegion on CumSum/Reduce ops; switch CumSum lowering to derive tvm_access_ptr from regions (1D/2D handling); add cumsum+view test and an ASTPrinter export with minor engine debug logging.

Changes

Cohort / File(s) Summary
C++ reduce implementation
src/op/reduce.cc
Add MakeAccessPtrFromRegion(const BufferRegion&, int) and ConvertBufferToBufferRegion(const Buffer&); extend NormalizeToBufferRegion to handle builtin.tvm_access_ptr(...); route CumSum lowering to construct access_ptrs from BufferRegion (compute offset/extent from region, 1D/2D branches), adjust validation/error messages and minor formatting.
C++ reduce declarations
src/op/reduce.h
Add srcRegion_ and dstRegion_ BufferRegion fields to CumSumOpNode and ReduceOpNode, expose via reflection.
TileLang Python reduce layer
tilelang/language/reduce.py
Replace direct access_ptr calls with buffer_to_tile_region wrappers in cumsum_fragment and cumsum, passing region-based src/dst into lowering.
TileLang analysis / AST printing
tilelang/analysis/ast_printer.py, tilelang/analysis/__init__.py
New ASTPrinter() prim_func_pass that pre-order visits and prints statement types; exported from tilelang.analysis.
Engine phase debug
tilelang/engine/phase.py
Add debug AST/module print at start of PreLowerSemanticCheck.
Tests — new issue test
testing/python/issue/test_tilelang_issue_1001.py
New test exercising cumsum over a view into shared memory to validate layout inference and region-based lowering.
Tests — minor decorator change
testing/python/language/test_tilelang_language_atomic_add.py
Simplified @tilelang.jit decorator by removing custom debug_root_path argument.

Sequence Diagram(s)

sequenceDiagram
    participant Py as TileLang Python
    participant Cpp as C++ Lowering (src/op/reduce.cc)
    participant IR as TVM IR

    Note over Py,Cpp: User invokes cumsum(view or buffer)
    Py->>Cpp: pass src/dst as BufferRegion + dim
    Cpp->>Cpp: Normalize input (ConvertBufferToBufferRegion / NormalizeToBufferRegion)
    Cpp->>Cpp: store srcRegion_/dstRegion_ on OpNode
    Cpp->>Cpp: MakeAccessPtrFromRegion(srcRegion, rw_mask)
    Cpp->>Cpp: MakeAccessPtrFromRegion(dstRegion, rw_mask)
    Cpp->>IR: Emit loads/stores using computed access_ptr (1D / 2D branches)
    IR-->>Py: Lowered PrimFunc returned
Loading

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

  • Pay attention to offset/extent arithmetic in MakeAccessPtrFromRegion across dimensionalities.
  • Verify NormalizeToBufferRegion handling of builtin.tvm_access_ptr(...) and that region/shape mapping for views and shared-memory cases is correct.
  • Check reflection/initialization of new srcRegion_/dstRegion_ fields and any impacts on serialization or passes.

Possibly related PRs

Poem

🐰 I hopped through buffers, regions mapped my trail,

Offsets snug in lanes where views would sometimes fail.
CumSum now follows regions, pointers born precise,
Shared strides and extents lining up like rice.
A small hop, a gentle fix — the kernel sings, concise.

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 45.45% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately summarizes the main change: using BufferRegion in tl.cumsum to fix buffer shape inference, which directly addresses the linked issue.
Linked Issues check ✅ Passed The PR implements region-based buffer handling for CumSum to fix shape inference errors, directly addressing the core objective of issue #1001.
Out of Scope Changes check ✅ Passed Minor changes to ASTPrinter and ReduceOp are supporting modifications; the primary focus on CumSum region-based handling aligns with the linked issue objective.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8f58e93 and c3e87a0.

📒 Files selected for processing (1)
  • testing/python/language/test_tilelang_language_atomic_add.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.

Applied to files:

  • testing/python/language/test_tilelang_language_atomic_add.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
🔇 Additional comments (1)
testing/python/language/test_tilelang_language_atomic_add.py (1)

263-263: LGTM: Consistency improvement.

Removing the debug_root_path parameter aligns this decorator with all other @tilelang.jit usages in the file, improving consistency.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copilot finished reviewing on behalf of SiriusNEO November 24, 2025 05:21
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR fixes issue #1001 where tl.cumsum incorrectly infers buffer shapes when using vmap to find buffers. The fix replaces the use of access_ptr with buffer_to_tile_region to properly handle BufferRegion-based shape inference.

Key Changes

  • Modified cumsum and cumsum_fragment in Python to use buffer_to_tile_region instead of access_ptr
  • Updated C++ CumSumOpNode to store and use BufferRegion objects (srcRegion_, dstRegion_) for accurate shape tracking
  • Added test case to verify the fix for view-based cumsum operations

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 8 comments.

Show a summary per file
File Description
tilelang/language/reduce.py Updated cumsum operations to use buffer_to_tile_region for proper region tracking
tilelang/engine/phase.py Added debug print statements (should be removed before merge)
tilelang/analysis/ast_printer.py New debugging utility for printing AST structure
tilelang/analysis/init.py Exported ASTPrinter for module-level access
src/op/reduce.h Added BufferRegion fields to CumSumOpNode for region-aware shape inference
src/op/reduce.cc Implemented BufferRegion normalization and updated CumSumOp constructor/Lower methods
testing/python/issue/test_tilelang_issue_1001.py Added regression test for issue #1001
Comments suppressed due to low confidence (3)

tilelang/engine/phase.py:230

  • Debug code should be removed before merging. This debug print statement appears to have been left in during development and should be removed from production code.
    return mod

tilelang/language/reduce.py:249

        buffer_to_tile_region(cumsum_smem),

tilelang/language/reduce.py:250

        buffer_to_tile_region(cumsum_smem),

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
tilelang/language/reduce.py (1)

242-253: buffer_to_tile_region calls are missing required access_type argument in cumsum_fragment

The buffer_to_tile_region function is defined with a required access_type: str parameter (utils.py:29), but both calls at lines 248–249 in cumsum_fragment omit it. This will raise a TypeError at runtime.

The suggested fix is correct—use "r" for the read operation and "w" for the write:

-    tir.call_intrin(
-        "handle",
-        tir.op.Op.get("tl.cumsum"),
-        buffer_to_tile_region(cumsum_smem),
-        buffer_to_tile_region(cumsum_smem),
-        dim,
-        reverse,
-    )
+    tir.call_intrin(
+        "handle",
+        tir.op.Op.get("tl.cumsum"),
+        buffer_to_tile_region(cumsum_smem, "r"),
+        buffer_to_tile_region(cumsum_smem, "w"),
+        dim,
+        reverse,
+    )

Existing tests (test_cumsum_fragment and test_cumsum_fragment_1d) should validate this path after the fix.

tilelang/engine/phase.py (1)

70-82: Remove unconditional debug output before merging

Both tilelang.analysis.ASTPrinter()(mod) at line 78 and print(mod) at line 230 execute unconditionally on every compilation, dumping IR to stdout and causing noisy, repeated output in normal use.

The codebase already has established debug flag patterns via PassConfigKey and pass_ctx.config (see TIR_ENABLE_DEBUG, TL_DEBUG_MERGE_SHARED_MEMORY_ALLOCATIONS in pass_config.py). Gate both statements behind a debug flag:

  • Line 78 (PreLowerSemanticCheck): Wrap ASTPrinter call behind a debug config check (may require adding pass_ctx parameter)
  • Line 230 (OptimizeForTarget): Wrap print(mod) behind pass_ctx.config.get(..., False) check (consistent with existing pattern on lines 55, 67, etc.)

Alternatively, remove entirely before merge if these were temporary debugging aids.

🧹 Nitpick comments (4)
testing/python/issue/test_tilelang_issue_1001.py (1)

1-33: Good focused regression test for the cumsum+view layout inference bug

The test captures the exact failing pattern (shared 1D buffer, T.view(..., (1, hidden)), T.cumsum(..., dim=1) with the relevant pass configs) and will fail if layout inference regresses again.

If you want extra coverage, you could optionally also compare x.cumsum(dim=1) against the kernel output, but for issue #1001 reproducing successful compilation/execution is already valuable.

src/op/reduce.cc (1)

61-70: Minor duplication in Buffer→Region conversion for tvm_access_ptr case

In NormalizeToBufferRegion, the builtin::tvm_access_ptr branch:

Var var = Downcast<Var>(call->args[1]);
Buffer buf = vmap[var];
Array<Range> ranges;
for (PrimExpr extent : buf->shape) {
  ranges.push_back(Range(IntImm(extent->dtype, 0), extent));
}
return BufferRegion(buf, ranges);

duplicates the logic in ConvertBufferToBufferRegion, which is currently unused:

static BufferRegion ConvertBufferToBufferRegion(const Buffer &buf) {
  Array<Range> ranges;
  for (PrimExpr extent : buf->shape) {
    ranges.push_back(Range(IntImm(extent->dtype, 0), extent));
  }
  return BufferRegion(buf, ranges);
}

You could either:

  • Replace the inlined construction with a call to ConvertBufferToBufferRegion(buf), or
  • Drop the helper entirely if you don’t plan to use it elsewhere.

This is purely a cleanliness/readability issue; behavior is already correct.

Also applies to: 569-577

tilelang/analysis/ast_printer.py (2)

2-3: Clean up the import formatting.

The PrimFunc import has unconventional formatting with a trailing comma and closing parenthesis on separate lines. Consider simplifying to a single line for consistency with other imports:

-from tvm.tir import (
-    PrimFunc,)
+from tvm.tir import PrimFunc

20-22: Consider prefixing unused parameters with underscore.

The mod and ctx parameters are required by the prim_func_pass signature but are not used in pass_fn. To indicate they are intentionally unused, consider prefixing them with underscores:

-    def pass_fn(func: PrimFunc, mod, ctx) -> PrimFunc:
+    def pass_fn(func: PrimFunc, _mod, _ctx) -> PrimFunc:
         new_body = ir_transform(func.body, pre_visit, None)
         return func.with_body(new_body)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9f7bac4 and 2dd5d6e.

📒 Files selected for processing (7)
  • src/op/reduce.cc (6 hunks)
  • src/op/reduce.h (2 hunks)
  • testing/python/issue/test_tilelang_issue_1001.py (1 hunks)
  • tilelang/analysis/__init__.py (1 hunks)
  • tilelang/analysis/ast_printer.py (1 hunks)
  • tilelang/engine/phase.py (2 hunks)
  • tilelang/language/reduce.py (2 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.

Applied to files:

  • testing/python/issue/test_tilelang_issue_1001.py
🧬 Code graph analysis (5)
tilelang/language/reduce.py (1)
tilelang/language/utils.py (1)
  • buffer_to_tile_region (29-41)
tilelang/analysis/ast_printer.py (1)
tilelang/language/v2/builder.py (1)
  • PrimFunc (517-526)
testing/python/issue/test_tilelang_issue_1001.py (1)
tilelang/transform/pass_config.py (1)
  • PassConfigKey (6-144)
src/op/reduce.cc (1)
src/op/gemm.cc (5)
  • strides (113-113)
  • NormalizeToBufferRegion (53-100)
  • NormalizeToBufferRegion (53-54)
  • MakeAccessPtrFromRegion (106-136)
  • MakeAccessPtrFromRegion (106-107)
tilelang/analysis/__init__.py (1)
tilelang/analysis/ast_printer.py (1)
  • ASTPrinter (8-24)
🪛 Ruff (0.14.5)
tilelang/analysis/ast_printer.py

20-20: Unused function argument: mod

(ARG001)


20-20: Unused function argument: ctx

(ARG001)

tilelang/analysis/__init__.py

3-3: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Agent
  • GitHub Check: CodeQL analysis (python)
  • GitHub Check: Quick Lint
🔇 Additional comments (4)
tilelang/language/reduce.py (1)

289-307: Region-based tl.cumsum call from Python looks consistent with new lowering

Switching cumsum to emit regions instead of raw access_ptr:

buffer_to_tile_region(src, "r"),
buffer_to_tile_region(dst, "w"),

aligns with the new NormalizeToBufferRegion / srcRegion_ / dstRegion_ flow and should correctly preserve view shapes (including T.view(smem, (1, hidden))).

src/op/reduce.h (1)

84-104: Region fields and reflection for ReduceOpNode / CumSumOpNode look correct

Adding srcRegion_ / dstRegion_ and exposing them as srcRegion / dstRegion via reflection is consistent with the new region-based lowering in reduce.cc. Clone behavior is preserved via copy-construction.

Also applies to: 135-151

src/op/reduce.cc (1)

579-599: CumSumOp constructor: region plumbing and dim bound check look good

Switching the constructor to:

  • Normalize args[0] / args[1] to srcRegion_ / dstRegion_ via NormalizeToBufferRegion, and
  • Derive src / dst buffers from those regions,

is exactly what’s needed to preserve view shapes (e.g., T.view(smem, (1, hidden))) instead of relying on the vmap-only buffer lookup. The added upper-bound check on dim vs. src->shape.size() also gives an earlier, clearer error.

Given the Python side normalizes negative dims before constructing CumSumOp, this check is sufficient.

tilelang/analysis/__init__.py (1)

3-3: LGTM! Static analysis hint is a false positive.

The noqa: F401 directive is intentional for this re-export pattern, making ASTPrinter available as part of the public API. The Ruff hint can be safely ignored.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

♻️ Duplicate comments (1)
src/op/reduce.cc (1)

594-595: Remove commented-out code.

These commented lines represent the old implementation and should be removed to reduce code clutter.

Apply this diff:

-  // node->src = vmap[GetVarFromAccessPtr(args[0])];
-  // node->dst = vmap[GetVarFromAccessPtr(args[1])];
🧹 Nitpick comments (2)
testing/python/language/test_tilelang_language_atomic_add.py (1)

263-263: Inconsistent decorator usage.

This is the only decorator in the file using @tilelang.jit() with parentheses, while all other instances (lines 5, 40, 77, 112, 147, 173, 209, 298, 331) use @tilelang.jit without parentheses.

Apply this diff to maintain consistency:

-@tilelang.jit()
+@tilelang.jit
src/op/reduce.cc (1)

94-111: Clarify comment for 2D-only branch.

The comment "Compute row-major strides for ndim >= 2" suggests support for 3+ dimensions, but the ICHECK at line 85 restricts this function to exactly 1D or 2D buffers. Since CumSum only supports 1D and 2D (as stated at line 639), this restriction is appropriate.

Apply this diff to clarify the comment:

-    // Compute row-major strides for ndim >= 2
+    // Compute row-major strides for 2D case
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 887638e and 8f58e93.

📒 Files selected for processing (4)
  • src/op/reduce.cc (6 hunks)
  • testing/python/language/test_tilelang_language_atomic_add.py (1 hunks)
  • tilelang/analysis/ast_printer.py (1 hunks)
  • tilelang/language/reduce.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tilelang/language/reduce.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.

Applied to files:

  • testing/python/language/test_tilelang_language_atomic_add.py
🧬 Code graph analysis (2)
src/op/reduce.cc (1)
src/op/gemm.cc (5)
  • strides (113-113)
  • NormalizeToBufferRegion (53-100)
  • NormalizeToBufferRegion (53-54)
  • MakeAccessPtrFromRegion (106-136)
  • MakeAccessPtrFromRegion (106-107)
tilelang/analysis/ast_printer.py (1)
tilelang/language/v2/builder.py (1)
  • PrimFunc (517-526)
🪛 Ruff (0.14.5)
tilelang/analysis/ast_printer.py

19-19: Unused function argument: mod

(ARG001)


19-19: Unused function argument: ctx

(ARG001)

🔇 Additional comments (5)
tilelang/analysis/ast_printer.py (1)

19-21: LGTM!

The unused mod and ctx parameters flagged by static analysis are part of the required callback signature for prim_func_pass and cannot be removed. The implementation correctly applies the pre-order visitor to print statement types for debugging.

src/op/reduce.cc (4)

61-70: LGTM!

The new handling for builtin.tvm_access_ptr correctly extracts the buffer variable and constructs a full BufferRegion using the buffer's shape. This enables proper region-based normalization for tvm_access_ptr calls.


576-584: LGTM!

The helper correctly converts a Buffer to a BufferRegion by creating full ranges [0, extent) for each dimension. This normalization is essential for uniform region-based handling.


596-606: LGTM!

The region-based normalization correctly handles various input forms (BufferRegion, BufferLoad, tvm_access_ptr) and the new dimension validation (lines 602-605) prevents invalid cumsum operations with out-of-range dimensions.


623-625: Core fix for cumsum + view shape inference.

This change correctly builds access pointers from BufferRegion instead of raw arguments, ensuring proper shape inference when T.cumsum is applied to viewed buffers. The region-based approach accurately captures the buffer's dimensions and offsets, addressing issue #1001.

@LeiWang1999 LeiWang1999 merged commit 9dda774 into tile-ai:main Nov 24, 2025
6 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug] Failed to infer layout with T.cumsum + view

2 participants