
@LeiWang1999 (Member) commented Oct 11, 2025

This pull request adds support for 1D cumulative sum (CumSum) operations in both the CUDA backend and the TileLang language, including corresponding tests for shared memory and fragment scopes. The changes include a new implementation for 1D CumSum on CUDA, updates to the lowering logic to handle 1D buffers, and new Python tests to verify correctness.

1D CumSum Support

  • Added a new CUDA template CumSum1D in reduce.h for efficient 1D cumulative sum operations, supporting both forward and reverse directions.
  • Updated the lowering logic in reduce.cc to handle 1D CumSum operations, including error checking for dimensionality and proper argument passing to the CUDA kernel.
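As a sketch of the semantics described above (a NumPy model for illustration only; the helper name is hypothetical and is not the actual kernel or TileLang API), a 1D inclusive cumulative sum with an optional reverse direction behaves like this:

```python
import numpy as np

def cumsum_1d(src, reverse=False):
    # Inclusive 1D cumulative sum; reverse=True accumulates right-to-left.
    # NumPy model of the behavior the CUDA CumSum1D template provides.
    if reverse:
        # Flip, scan forward, flip back: a right-to-left inclusive scan.
        return np.flip(np.cumsum(np.flip(src)))
    return np.cumsum(src)

x = np.array([1, 2, 3, 4])
print(cumsum_1d(x).tolist())                # [1, 3, 6, 10]
print(cumsum_1d(x, reverse=True).tolist())  # [10, 9, 7, 4]
```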

Testing Enhancements

  • Added new Python test functions (cumsum_smem_test_1d, cumsum_fragment_test_1d, and run_cumsum_1d) to test 1D CumSum behavior in both shared memory and fragment scopes. These tests include reference implementations and correctness checks against the CUDA kernel output.
  • Introduced new test cases (test_cumsum_smem_1d, test_cumsum_fragment_1d) to the test suite to ensure comprehensive coverage of 1D CumSum operations, including reverse mode.

Summary by CodeRabbit

  • New Features

    • Added 1D cumulative sum support with forward and reverse modes, alongside existing 2D behavior.
    • Automatically handles 1D vs 2D inputs; provides clearer errors for unsupported shapes.
  • Tests

    • Added comprehensive 1D cumulative sum tests (including shared-memory and fragment variants) with reference validation.
  • Documentation

    • Expanded cumsum examples to cover 1D usage and 2D reverse accumulation.

@github-actions commented

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

coderabbitai bot (Contributor) commented Oct 11, 2025

Walkthrough

Adds 1D CumSum support end-to-end: new CUDA template CumSum1D, lowers CumSumOp to call 1D or 2D kernels based on input rank, errors for >2D, updates tests to cover 1D smem/fragment variants, and expands docstring examples. No other public APIs changed.

Changes

  • Lowering logic: CumSum rank-aware branching (src/op/reduce.cc)
    CumSumOpNode::Lower now branches by input rank: calls tl::CumSum1D<...>::run for 1D (dim must be 0) and tl::CumSum2D<...>::run for 2D; raises a fatal error for >2D. Adjusts extern arg construction accordingly.
  • CUDA template: 1D cumulative sum (src/tl_templates/cuda/reduce.h)
    Introduces template <int threads, bool reverse> struct CumSum1D with run(dst, src, N), implementing an inclusive scan using warp shuffles; supports forward and reverse modes; static_assert on thread count. No changes to existing templates.
  • Tests: 1D CumSum coverage (testing/python/language/test_tilelang_language_cumsum.py)
    Adds 1D kernels/tests: smem and fragment variants, runner run_cumsum_1d, and entry tests test_cumsum_smem_1d/test_cumsum_fragment_1d, including reverse and reference checks.
  • Docs: cumsum examples (tilelang/language/reduce.py)
    Expands the cumsum docstring with 1D and 2D usage examples, including reverse accumulation. No logic changes.
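The rank-based dispatch summarized above can be sketched in Python (a hypothetical model of the branching, not the actual C++ lowering code; the function name is invented for illustration):

```python
def lower_cumsum(shape, threads, dim=0, reverse=False):
    # Choose the CUDA template by input rank, mirroring the rank-aware
    # branching described for CumSumOpNode::Lower (illustrative sketch).
    ndim = len(shape)
    rev = "true" if reverse else "false"
    if ndim == 1:
        if dim != 0:
            raise ValueError("1D CumSum requires dim == 0")
        return f"tl::CumSum1D<{threads}, {rev}>::run"
    if ndim == 2:
        return f"tl::CumSum2D<{threads}, {dim}, {rev}>::run"
    raise ValueError("CumSum supports 1D and 2D inputs only")

print(lower_cumsum((128,), 128))           # tl::CumSum1D<128, false>::run
print(lower_cumsum((64, 64), 128, dim=1))  # tl::CumSum2D<128, 1, false>::run
```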

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant IR as CumSumOpNode::Lower
  participant Extern as Extern Call Builder
  participant TL1D as tl::CumSum1D
  participant TL2D as tl::CumSum2D
  participant Err as Fatal Error

  IR->>IR: Inspect input rank (ndim)
  alt ndim == 1
    IR->>Extern: Build args (dst, src, N, threads, reverse)
    Extern->>TL1D: run(dst, src, N)
    Note over TL1D: Inclusive scan (forward or reverse)
  else ndim == 2
    IR->>Extern: Build args (dst, src, M, N, threads, dim, reverse)
    Extern->>TL2D: run(dst, src, M, N)
    Note over TL2D: Inclusive scan along chosen dim
  else (>2)
    IR-->>Err: TL_FAIL("CumSum supports 1D/2D only")
  end

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Poem

A hare with threads in tidy rows,
Shuffles sum where data flows.
One lane, two lanes—watch me glide,
Forward, backward—prefix pride.
Tests now hop in joyful spree,
Docs leave trails for all to see.
Cumsum burrows: 1D, 2D! 🥕✨

Pre-merge checks

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 7.14%, below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve coverage.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed: Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: The title succinctly captures the primary change, the new 1D cumulative sum operator "CumSum1D" within the TileOp framework, making the core addition immediately clear to reviewers.


coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

🧹 Nitpick comments (2)
src/op/reduce.cc (1)

423-441: The 1D/2D branching logic is correct.

The implementation properly distinguishes between 1D and 2D cases, with appropriate dimension checks and error handling for unsupported dimensionalities. The argument construction and template instantiation for both CumSum1D and CumSum2D are correct.

Consider using string access modes for consistency.

Lines 430 and 435 use integer access modes (1, 3) for access_ptr, while other parts of the codebase (e.g., lines 37-38 in ReduceOpNode) use string modes ("r", "w"). Using strings would improve readability and maintain consistency.

Apply this diff for better consistency:

-      args = {StringImm(ss.str()), src.access_ptr(1), dst.access_ptr(3),
-              src->shape[0]};
+      args = {StringImm(ss.str()), src.access_ptr("r"), dst.access_ptr("w"),
+              src->shape[0]};
     } else if (ndim == 2) {
       ss << "tl::CumSum2D<" << threads << ", " << dim << ", "
          << (reverse ? "true" : "false") << ">::run";
-      args = {StringImm(ss.str()), src.access_ptr(1), dst.access_ptr(3),
-              src->shape[0], src->shape[1]};
+      args = {StringImm(ss.str()), src.access_ptr("r"), dst.access_ptr("w"),
+              src->shape[0], src->shape[1]};
testing/python/language/test_tilelang_language_cumsum.py (1)

74-141: Well-structured 1D cumsum tests with correct reference implementation.

The 1D test functions properly mirror the 2D test structure and correctly validate block-local cumulative sum behavior:

  • Test kernels (cumsum_smem_test_1d, cumsum_fragment_test_1d): Correctly implement 1D cumsum in both shared memory and fragment scopes.
  • Reference implementation (lines 123-136): Accurately models block-local cumsum, processing each block independently and handling reverse mode by flipping before/after the cumsum operation.
  • Test coverage: Includes both forward and reverse modes, matching the capabilities of the underlying CUDA kernel.
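The block-local reference logic described above can be sketched in NumPy (the helper name is hypothetical; the flip-before/flip-after handling of reverse mode follows the description):

```python
import numpy as np

def blockwise_cumsum_ref(x, block, reverse=False):
    # Reference model: each block-sized chunk is scanned independently;
    # reverse mode flips the chunk before and after the scan.
    out = x.copy()
    for start in range(0, len(x), block):
        chunk = out[start:start + block]  # view; writes land in out
        if reverse:
            chunk[:] = np.flip(np.cumsum(np.flip(chunk)))
        else:
            chunk[:] = np.cumsum(chunk)
    return out

x = np.arange(1, 9)
print(blockwise_cumsum_ref(x, block=4).tolist())
# [1, 3, 6, 10, 5, 11, 18, 26]
```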

Extract exception message to reduce line length.

Line 118 has a long error message embedded in the raise statement. Consider extracting it to a variable for better readability.

Apply this diff:

-        raise ValueError(f"Unknown scope {scope}")
+        msg = f"Unknown scope {scope}. Expected 'smem' or 'fragment'."
+        raise ValueError(msg)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0ae183d and 441d05e.

📒 Files selected for processing (4)
  • src/op/reduce.cc (2 hunks)
  • src/tl_templates/cuda/reduce.h (1 hunks)
  • testing/python/language/test_tilelang_language_cumsum.py (2 hunks)
  • tilelang/language/reduce.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
testing/python/language/test_tilelang_language_cumsum.py (4)
tilelang/language/reduce.py (1)
  • cumsum (157-207)
tilelang/language/allocate.py (2)
  • alloc_shared (21-36)
  • alloc_fragment (53-64)
tilelang/language/copy.py (1)
  • copy (10-86)
tilelang/jit/__init__.py (1)
  • compile (34-90)
🪛 Ruff (0.13.3)
testing/python/language/test_tilelang_language_cumsum.py

118-118: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: build-test-metal
  • GitHub Check: format-check
🔇 Additional comments (3)
tilelang/language/reduce.py (1)

163-185: LGTM! Clear and helpful documentation examples.

The added examples effectively demonstrate both 1D and 2D cumsum usage patterns, showing proper buffer allocation, copying, and the cumsum operation with different configurations (including reverse mode). These examples will help users understand the API.

src/tl_templates/cuda/reduce.h (1)

71-137: Excellent implementation of 1D cumulative sum with warp-level primitives.

The CumSum1D template correctly implements both forward and reverse cumulative sum using efficient warp shuffle operations:

  • Forward mode: Uses __shfl_up_sync for inclusive prefix sum within each segment, correctly propagates carry from left to right across segments.
  • Reverse mode: Uses __shfl_down_sync for reverse cumulative sum within each segment, correctly propagates carry from right to left across segments.
  • Edge cases: Properly handles N ≤ 0 (early return) and partial segments (boundary checks at lines 93, 104, 116, 127).
  • Design: Restricting to first SEG threads (lines 84-85) is appropriate for warp-level operations and sequential carry dependency.
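The shuffle-based intra-segment scan plus cross-segment carry described above can be modeled sequentially (a pure-Python illustration of the algorithm under these assumptions, not the CUDA code itself):

```python
def segmented_inclusive_scan(x, seg=32):
    # Hillis-Steele inclusive scan within each seg-sized segment (the
    # effect __shfl_up_sync achieves per warp), then a carry propagated
    # sequentially left to right across segments.
    out = list(x)
    n = len(out)
    for start in range(0, n, seg):
        end = min(start + seg, n)
        offset = 1
        while offset < end - start:
            prev = out[start:end]  # snapshot, like reading lane values
            for i in range(start + offset, end):
                out[i] = prev[i - start] + prev[i - start - offset]
            offset *= 2
    carry = 0
    for start in range(0, n, seg):
        end = min(start + seg, n)
        for i in range(start, end):
            out[i] += carry
        carry = out[end - 1]  # running total through this segment
    return out

print(segmented_inclusive_scan([1] * 8, seg=4))  # [1, 2, 3, 4, 5, 6, 7, 8]
```

The reverse-mode kernel mirrors this with `__shfl_down_sync` and a right-to-left carry.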
testing/python/language/test_tilelang_language_cumsum.py (1)

164-171: LGTM! Test entry points provide good coverage.

The test functions exercise both shared memory and fragment scopes with forward and reverse modes, providing comprehensive coverage of the 1D cumsum functionality.

@LeiWang1999 (Member, Author) commented

Local tests pass; merged.

LeiWang1999 merged commit 747381a into tile-ai:main on Oct 11, 2025
7 of 8 checks passed
RubiaCx pushed a commit to RubiaCx/tilelang that referenced this pull request Nov 24, 2025
* support cumsum-1d

* cumsum 1d support
