
Conversation

@LeiWang1999 (Member) commented Oct 20, 2025

This pull request adds support for shared-memory reduction operations in both CUDA and HIP backends, refactors the reduction lowering logic to handle different memory scopes, and introduces comprehensive tests for shared-memory reductions. The main focus is on enabling and validating efficient reductions (sum, max, min, abs-sum, abs-max, bitwise ops) in shared memory, with backend-specific optimizations for CUDA (warp size 32) and HIP (wave size 64).

Shared-memory reduction support

  • Added a generic SharedReduceWarp template in both src/tl_templates/cuda/reduce.h and src/tl_templates/hip/reduce.h for efficient warp/wave-level reductions, supporting sum, max, min, abs-sum, abs-max, bitwise and/or/xor, and accumulation logic. [1] [2]
  • Implemented backend-specific details for CUDA (warp size 32, __shfl_down_sync, __activemask) and HIP (wave size 64, __shfl_down); a minimal sketch of the warp-shuffle pattern follows below. [1] [2]
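To make the warp-level mechanism concrete, below is a minimal, self-contained CUDA sketch of a single-warp sum using the same intrinsics. It is illustrative only and is not the tl::SharedReduceWarp template; the kernel name, launch shape, and data are invented for the example.

```cuda
// Illustrative only: a one-warp sum reduction using the shuffle intrinsics the
// PR relies on for CUDA (warp size 32). Not the actual tl::SharedReduceWarp.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void warp_sum(const float* in, float* out) {
  const int lane = threadIdx.x & 31;      // lane id within the warp
  float partial = in[threadIdx.x];        // each lane starts with one element
  const unsigned mask = __activemask();   // lanes currently active
  // Tree reduction: halve the offset each step; lane 0 ends with the warp sum.
  for (int offset = 16; offset > 0; offset >>= 1) {
    partial += __shfl_down_sync(mask, partial, offset);
  }
  if (lane == 0) {
    out[0] = partial;
  }
}

int main() {
  float h_in[32], h_out = 0.0f;
  for (int i = 0; i < 32; ++i) h_in[i] = 1.0f;  // expected sum: 32
  float *d_in = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  warp_sum<<<1, 32>>>(d_in, d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("warp sum = %g\n", h_out);  // prints 32
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
```

On HIP the same loop would start at offset 32 (wave size 64) and use __shfl_down without a mask argument.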

Reduction lowering logic refactor

  • Refactored ReduceOpNode::Lower in src/op/reduce.cc to support shared-memory reductions, including scope checks, buffer handling, dimension checks, and backend-specific constraints. The lowering now emits shared-memory reduction calls for CUDA/HIP and falls back to thread-local reduction for fragment buffers. [1] [2] [3]
  • Improved handling of initialization, accumulation, and buffer duplication for the different reduction types, streamlining the logic for sum, abs-sum, and bitwise ops and tightening error handling for unsupported types; the clear/accumulate semantics are sketched below. [1] [2]
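As a concrete reading of the clear/accumulate handling, the toy host-side sketch below shows the intended semantics for a row-wise sum: clear=true overwrites the destination with the fresh reduction, clear=false combines the fresh reduction into the existing destination. This is not the TVM lowering code; the function names and shapes are invented for illustration.

```cuda
// Toy illustration of clear vs. accumulate semantics for a row-wise sum.
// Bitwise reductions follow the same pattern with &, |, ^ as the combiner.
#include <cstdio>

float reduce_row_sum(const float* row, int n) {
  float acc = 0.0f;
  for (int i = 0; i < n; ++i) acc += row[i];
  return acc;
}

void reduce_sum_rows(const float* src, float* dst, int rows, int cols, bool clear) {
  for (int r = 0; r < rows; ++r) {
    float fresh = reduce_row_sum(src + r * cols, cols);  // freshly reduced value
    dst[r] = clear ? fresh : dst[r] + fresh;             // overwrite vs. accumulate
  }
}

int main() {
  const float src[2][3] = {{1, 2, 3}, {4, 5, 6}};
  float acc_dst[2] = {100.0f, 200.0f};
  reduce_sum_rows(&src[0][0], acc_dst, 2, 3, /*clear=*/false);
  std::printf("accumulate: %g %g\n", acc_dst[0], acc_dst[1]);  // 106 215
  float clr_dst[2] = {100.0f, 200.0f};
  reduce_sum_rows(&src[0][0], clr_dst, 2, 3, /*clear=*/true);
  std::printf("clear:      %g %g\n", clr_dst[0], clr_dst[1]);  // 6 15
  return 0;
}
```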

Bitwise reduction support

  • Added BitAndOp, BitOrOp, and BitXorOp functors to the HIP backend for parity with CUDA; an illustrative sketch of the functor shape follows below.
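The functor shape itself is simple; the sketch below is a hypothetical rendering of what such reducers look like (the real definitions live in src/tl_templates/hip/reduce.h and may differ in naming details), plus a tiny host-side check of the semantics.

```cuda
// Hypothetical sketch of bitwise reducer functors; a reducer is just a binary
// functor, so bitwise ops slot into the same generic reduction machinery as
// sum/max/min. Not the exact code from src/tl_templates/hip/reduce.h.
#include <cstdint>
#include <cstdio>

template <typename T> struct BitAndOp {
  __host__ __device__ T operator()(T a, T b) const { return a & b; }
};
template <typename T> struct BitOrOp {
  __host__ __device__ T operator()(T a, T b) const { return a | b; }
};
template <typename T> struct BitXorOp {
  __host__ __device__ T operator()(T a, T b) const { return a ^ b; }
};

int main() {
  BitAndOp<uint32_t> band;
  BitOrOp<uint32_t>  bor;
  BitXorOp<uint32_t> bxor;
  std::printf("and: 0x%x\n", band(0xF0u, 0x3Cu));  // 0x30
  std::printf("or : 0x%x\n", bor(0xF0u, 0x3Cu));   // 0xFC
  std::printf("xor: 0x%x\n", bxor(0xF0u, 0x3Cu));  // 0xCC
  return 0;
}
```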

Comprehensive testing

  • Introduced testing/python/language/test_tilelang_language_reduce.py with extensive tests for shared-memory reductions (sum, max, min, abs-sum, abs-max), including correctness checks, clear/accumulate semantics, and backend validation.

Error handling and code quality

  • Improved error messages and assertions for dimension mismatches, unsupported scopes, and thread extent constraints, ensuring robustness and easier debugging. [1] [2]

These changes collectively enable efficient, robust, and well-tested shared-memory reductions across CUDA and HIP backends.

Summary by CodeRabbit

  • New Features

    • Enhanced reduction capabilities: warp/shared-memory reductions, optional absolute-value handling, optional accumulation, and new bitwise reduction operations; improved multi-stage local reductions and runtime guards.
  • Bug Fixes

    • Fixed CUDA dynamic memory error-check handling.
  • Tests

    • Reorganized and expanded reduction tests for shared/local and clear/non-clear modes; obsolete test modules consolidated/removed.

@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai (Contributor) bot commented Oct 20, 2025

Walkthrough

Refactors ReduceOp lowering to resolve buffer remaps and branch by buffer scope; adds warp-level SharedReduceWarp templates for CUDA/HIP (with abs/accumulate and bitwise ops); consolidates/expands reduction tests into a unified suite; and tweaks a CUDA error constant in wrapper.py.

Changes

  • Core reduction lowering (src/op/reduce.cc): Reworks ReduceOpNode::Lower to resolve remapped buffers and branch by buffer scope; implements distinct local.fragment and shared/shared.dyn reduction paths, reorganizes init/duplication/reduction stages, adds runtime guards/fatal logs for unsupported combos, and returns an Evaluate(Call(...)) for the shared-memory path.
  • CUDA reduction templates (src/tl_templates/cuda/reduce.h): Adds SharedReduceWarp<Reducer, Threads, UseAbs, NeedAccumulate> with a device run implementing per-warp multi-destination reductions using shuffle intrinsics; supports optional absolute-value handling and optional accumulation.
  • HIP reduction templates (src/tl_templates/hip/reduce.h): Adds bitwise functors (BitAndOp, BitOrOp, BitXorOp), the SharedReduceWarp template, and AllReduce scaffolding for warp/shared-memory reductions with assertions and accumulate/abs options.
  • Unified/expanded tests (testing/python/language/test_tilelang_language_reduce.py): Adds a comprehensive reduction test suite covering local and shared reductions (sum, max, min, abs-sum, abs-max), clear/non-clear modes, multiple dtypes/sizes, and PyTorch reference checks; introduces builders and runners for shared/local cases.
  • Removed standalone tests (testing/python/language/test_tilelang_language_reduce_max.py, testing/python/language/test_tilelang_language_reduce_sum.py): Deletes the prior separate reduce_max and reduce_sum test modules; their functionality is consolidated into the new unified test file.
  • Wrapper change (tilelang/jit/adapter/wrapper.py): Changes the CUDA error comparison constant from CUDA_SUCCESS to cudaSuccess in the dynamic shared memory size path.

Sequence Diagram

sequenceDiagram
    participant Lower as ReduceOp Lowering
    participant Remap as Buffer Remap Lookup
    participant Scope as Scope Resolver
    participant Local as local.fragment Path
    participant Shared as shared / shared.dyn Path
    participant Warp as SharedReduceWarp (CUDA/HIP)
    participant Runtime as Runtime / Logs

    Lower->>Remap: get_buffer(src/dst)
    Remap-->>Scope: return buffer + scope
    Scope-->>Lower: src_scope, dst_scope

    alt both local.fragment
        Lower->>Local: validate dims, insert reduce dim
        Local->>Local: allocate/clear dup buffer if needed
        Local->>Local: per-dimension loops & reduction pipeline
        Local->>Lower: assemble final write-back (sum/abs/bitwise)
    else shared/shared.dyn
        Lower->>Shared: validate dims & platform thread extents
        Shared->>Warp: build params (total_dest, reduce_extent, tail, init)
        Warp->>Warp: per-warp reduction (shuffle/accumulate)
        Warp-->>Shared: return Evaluate(Call(...))
        Shared->>Lower: integrate call into schedule
    else
        Lower->>Runtime: Fatal("unsupported scope/type")
    end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes


Suggested reviewers

  • tzj-fxz
  • chengyupku
  • xysmlx

Poem

🐰 I hop through lanes of shared and local rows,

I shuffle bits where reduction grows,
Buffers remapped, threads hum in sync,
I fold the sums before you blink,
A carrot-coded cheer — then off I go!

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 50.00%, which is below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The PR title "[Language] Efficient T.reduce_ with shared memory input/output" directly and accurately represents the main objective of the changeset. The title clearly identifies the API being enhanced (T.reduce_), specifies the primary improvement (shared memory support), and uses concrete terminology rather than generic phrasing. The changes across src/op/reduce.cc, the reduce.h templates for both CUDA and HIP, and the comprehensive test additions all align with this stated goal. The title is concise, avoids file listings or vague terms, and provides sufficient clarity for a developer scanning pull request history to understand the core contribution.

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d06ec58 and 8db0195.

📒 Files selected for processing (1)
  • src/op/reduce.cc (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/op/reduce.cc (3)
src/layout/utils.cc (10)
  • CompressIterator (200-244)
  • CompressIterator (200-203)
  • iter_sum (174-186)
  • iter_sum (174-174)
  • ToVMap (254-260)
  • ToVMap (254-254)
  • mark (57-67)
  • mark (57-57)
  • mark (188-194)
  • mark (188-188)
src/target/utils.cc (8)
  • TargetIsHopper (49-54)
  • TargetIsHopper (49-49)
  • TargetIsSm100 (56-61)
  • TargetIsSm100 (56-56)
  • TargetIsCuda (11-13)
  • TargetIsCuda (11-11)
  • TargetIsRocm (14-16)
  • TargetIsRocm (14-14)
src/transform/loop_partition.cc (2)
  • PartitionLoop (61-163)
  • PartitionLoop (61-62)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
🔇 Additional comments (3)
src/op/reduce.cc (3)

178-186: LGTM! Clean refactoring to support multiple reduction paths.

The get_buffer helper cleanly resolves buffer remaps, and extracting scopes enables the new shared-memory branch. The logic is straightforward and correct.


362-364: LGTM! Clean helper for shared-memory scope detection.

The lambda provides a clear, reusable check for shared-memory scopes, improving readability.


426-428: LGTM! Improved error message for unsupported scope combinations.

The error message clearly communicates the unsupported scope pair and includes actual values, which aids debugging.



@coderabbitai (Contributor) bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
tilelang/jit/adapter/wrapper.py (2)

500-504: Bug: l2_persistent_map stored under stale function_name.

Inside this loop, function_name refers to the last device func from a previous loop, so you end up assigning the entire L2 map to the wrong key. Load the whole mapping once instead.

Apply this diff:

-            if "l2_persistent_map" in func.attrs:
-                self.l2_persistent_map[function_name] = func.attrs["l2_persistent_map"]
+            if "l2_persistent_map" in func.attrs:
+                # Attr is a per-kernel map: { kernel_name: {buffer_name: (hit_ratio, bytes), ...}, ... }
+                self.l2_persistent_map = dict(func.attrs["l2_persistent_map"])

1129-1133: Undefined attribute in error path.

self.arch.platform is not defined on TLWrapper/TLPyWrapper. Use self.target to report the unsupported target.

Apply this diff:

-        else:
-            raise ValueError(f"Unsupported platform: {self.arch.platform}")
+        else:
+            raise ValueError(f"Unsupported target: {self.target}")

And similarly in TLPyWrapper.wrap:

-        else:
-            raise ValueError(f"Unsupported platform: {self.arch.platform}")
+        else:
+            raise ValueError(f"Unsupported target: {self.target}")

Also applies to: 1151-1154

src/op/reduce.cc (1)

56-68: Fix overflow in init constants for int/uint min/max.

1 << ... overflows for 32/64-bit widths (UB). Use 64-bit shifts then cast.

Apply this diff:

-    if (is_int) {
-      return make_const(dst->dtype, -(1 << (bits - 1)));
+    if (is_int) {
+      return make_const(dst->dtype, -(int64_t(1) << (bits - 1)));
     } else if (is_uint) {
       return make_const(dst->dtype, 0);
     } else {
       return make_const(dst->dtype, -INFINITY);
     }
@@
-    if (is_int) {
-      return make_const(dst->dtype, (1 << (bits - 1)) - 1);
+    if (is_int) {
+      return make_const(dst->dtype, (int64_t(1) << (bits - 1)) - 1);
     } else if (is_uint) {
-      return make_const(dst->dtype, (1 << bits) - 1);
+      return make_const(dst->dtype, (uint64_t(1) << bits) - 1);
     } else {
       return make_const(dst->dtype, INFINITY);
     }

Also applies to: 64-71, 74-81

🧹 Nitpick comments (5)
tilelang/jit/adapter/wrapper.py (1)

732-763: NVRTC: device index detection may default to 0.

If no ctypes.c_void_p args are present, device_index remains 0, which can set attributes on the wrong device in multi‑GPU runs. Consider falling back to the current device from the runtime or annotating kernels with device id.

Would you like a small patch to fetch the current CUDA device (via your cuda.bindings.runtime shim) when no buffer args are found?

src/tl_templates/cuda/reduce.h (1)

74-77: Specify shuffle width to match kWarpSize.

Pass kWarpSize to __shfl_down_sync for clarity and safety.

Apply this diff:

-      for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) {
-        T other = __shfl_down_sync(mask, partial, offset);
+      for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) {
+        T other = __shfl_down_sync(mask, partial, offset, kWarpSize);
src/op/reduce.cc (2)

197-203: Nit: misleading scalar reduce message.

The check enforces scalar output; message says “not implemented.” Consider clarifying.

Apply this diff:

-      ICHECK(is_one(dst_layout->OutputShape().back()))
-          << "Reduce for scalar not implemented.";
+      ICHECK(is_one(dst_layout->OutputShape().back()))
+          << "Expect scalar output layout (last extent = 1) for 1D reduce.";

307-314: Workspace sizing for AllReduce (minor).

You always allocate workspace sized by T.thread_bounds->extent, but only need it when reducing_threads >= 32. Consider sizing by reducing_threads to save shared/local memory.

testing/python/language/test_tilelang_language_reduce.py (1)

133-143: Add bitwise reduction tests for parity and regressions.

Since BitAnd/BitOr/BitXor were added, include shared-memory tests verifying:

  • and/or/xor over dim=1 with clear=True/False
  • integer dtypes (e.g., int32, uint32)

I can draft minimal builders mirroring reduce_*_ss.

Also applies to: 145-163, 186-204, 206-224

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between fd6cec5 and 1d31ee3.

📒 Files selected for processing (7)
  • src/op/reduce.cc (1 hunks)
  • src/tl_templates/cuda/reduce.h (1 hunks)
  • src/tl_templates/hip/reduce.h (1 hunks)
  • testing/python/language/test_tilelang_language_reduce.py (1 hunks)
  • testing/python/language/test_tilelang_language_reduce_max.py (0 hunks)
  • testing/python/language/test_tilelang_language_reduce_sum.py (0 hunks)
  • tilelang/jit/adapter/wrapper.py (1 hunks)
💤 Files with no reviewable changes (2)
  • testing/python/language/test_tilelang_language_reduce_sum.py
  • testing/python/language/test_tilelang_language_reduce_max.py
🧰 Additional context used
🧬 Code graph analysis (4)
src/tl_templates/cuda/reduce.h (1)
src/tl_templates/hip/reduce.h (1)
  • void (46-87)
src/tl_templates/hip/reduce.h (1)
src/tl_templates/cuda/reduce.h (3)
  • T (208-280)
  • void (46-87)
  • run (140-201)
src/op/reduce.cc (3)
src/layout/utils.cc (10)
  • CompressIterator (200-244)
  • CompressIterator (200-203)
  • iter_sum (174-186)
  • iter_sum (174-174)
  • ToVMap (254-260)
  • ToVMap (254-254)
  • mark (57-67)
  • mark (57-57)
  • mark (188-194)
  • mark (188-188)
src/target/utils.cc (8)
  • TargetIsHopper (49-54)
  • TargetIsHopper (49-49)
  • TargetIsSm100 (56-61)
  • TargetIsSm100 (56-56)
  • TargetIsCuda (11-13)
  • TargetIsCuda (11-11)
  • TargetIsRocm (14-16)
  • TargetIsRocm (14-14)
src/transform/loop_partition.cc (2)
  • PartitionLoop (61-163)
  • PartitionLoop (61-62)
testing/python/language/test_tilelang_language_reduce.py (10)
src/tl_templates/cuda/reduce.h (2)
  • tl (5-88)
  • T (208-280)
src/tl_templates/hip/reduce.h (1)
  • tl (5-88)
tilelang/testing/__init__.py (1)
  • set_random_seed (30-35)
tilelang/language/allocate.py (2)
  • alloc_shared (21-36)
  • alloc_fragment (53-64)
tilelang/language/copy.py (1)
  • copy (10-86)
tilelang/jit/__init__.py (1)
  • compile (34-90)
tilelang/jit/kernel.py (2)
  • out_idx (462-463)
  • get_profiler (376-392)
tilelang/profiler/__init__.py (1)
  • assert_allclose (76-145)
tilelang/language/reduce.py (5)
  • reduce_max (50-68)
  • reduce_sum (87-109)
  • reduce_min (71-84)
  • reduce_abssum (112-124)
  • reduce_absmax (127-139)
tilelang/language/fill.py (2)
  • fill (9-21)
  • clear (24-48)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Quick Lint
🔇 Additional comments (3)
tilelang/jit/adapter/wrapper.py (1)

13-19: Correct CUDA runtime error check.

Switching to cudaSuccess matches cudaError_t/cuda runtime usage. Good fix.
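For context on why the constant matters: CUDA Runtime API calls return cudaError_t and compare against cudaSuccess, while Driver API calls return CUresult and use CUDA_SUCCESS. A minimal sketch of the distinction follows; the kernel name and the 48 KB value are invented, and this is not the code wrapper.py generates.

```cuda
// Runtime API errors (cudaError_t) compare against cudaSuccess; Driver API
// errors (CUresult) compare against CUDA_SUCCESS. Illustrative only.
#include <cstdio>
#include <cuda.h>          // Driver API: CUresult, CUDA_SUCCESS
#include <cuda_runtime.h>  // Runtime API: cudaError_t, cudaSuccess

__global__ void dummy_kernel() {}

int main() {
  // Runtime API: raise the dynamic shared memory limit for a kernel.
  cudaError_t rerr = cudaFuncSetAttribute(
      reinterpret_cast<const void*>(dummy_kernel),
      cudaFuncAttributeMaxDynamicSharedMemorySize, 48 * 1024);
  if (rerr != cudaSuccess) {
    std::printf("runtime error: %s\n", cudaGetErrorString(rerr));
  }

  // Driver API: a different enum and a different success constant.
  CUresult derr = cuInit(0);
  if (derr != CUDA_SUCCESS) {
    std::printf("driver init failed: %d\n", static_cast<int>(derr));
  }
  return 0;
}
```

Since the dynamic shared memory attribute is set through the Runtime API, cudaSuccess is the matching constant for that check.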

src/tl_templates/hip/reduce.h (1)

25-41: HIP reduction primitives look good.

Wave-size handling, shuffle width, and bitwise functors align with CUDA parity.

Also applies to: 43-88, 90-115

src/op/reduce.cc (1)

387-392: Shared-memory reduce lowering looks consistent.

  • abs handling uses UseAbs flag
  • accumulate semantics map to NeedAccumulate
  • dst passed with RW access_ptr

LGTM.

Also applies to: 404-418

@coderabbitai (Contributor) bot left a comment

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/op/reduce.cc (1)

141-176: Update outdated documentation.

The comment block describes the old implementation that only supported local.fragment scope. With the addition of shared-memory reduction support (lines 366-424), this documentation is now outdated and misleading.

Update the documentation to reflect both reduction paths and their respective behaviors.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1d31ee3 and d06ec58.

📒 Files selected for processing (1)
  • src/op/reduce.cc (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/op/reduce.cc (3)
src/layout/utils.cc (10)
  • CompressIterator (200-244)
  • CompressIterator (200-203)
  • iter_sum (174-186)
  • iter_sum (174-174)
  • ToVMap (254-260)
  • ToVMap (254-254)
  • mark (57-67)
  • mark (57-57)
  • mark (188-194)
  • mark (188-188)
src/target/utils.cc (8)
  • TargetIsHopper (49-54)
  • TargetIsHopper (49-49)
  • TargetIsSm100 (56-61)
  • TargetIsSm100 (56-56)
  • TargetIsCuda (11-13)
  • TargetIsCuda (11-11)
  • TargetIsRocm (14-16)
  • TargetIsRocm (14-14)
src/transform/loop_partition.cc (2)
  • PartitionLoop (61-163)
  • PartitionLoop (61-62)
🔇 Additional comments (4)
src/op/reduce.cc (4)

178-185: LGTM!

The get_buffer helper cleanly handles buffer remapping, and extracting scopes upfront enables clear branching logic between local.fragment and shared-memory paths.


341-359: LGTM!

The final assembly correctly wraps the reduction body in parallel loops, applies loop partitioning for thread-level parallelism, and handles the scalar case with an appropriate thread guard. The conditional allocation of the clear buffer is also properly placed.


398-407: LGTM!

The extent calculations correctly compute the reduction dimension size, the tail extent for dimensions beyond the reduction axis (defaulting to 1 when reducing the last dimension), and the total destination size. These values properly parameterize the SharedReduceWarp call.


426-428: LGTM!

The final error handler provides a clear, informative message for unsupported scope combinations, making debugging easier when an unexpected case is encountered.

Comment on lines +198 to +199
ICHECK(is_one(dst_layout->OutputShape().back()))
<< "Reduce for scalar not implemented.";

⚠️ Potential issue | 🟡 Minor

Clarify confusing error message.

The assertion checks that dst_layout->OutputShape().back() is 1, but the error message states "Reduce for scalar not implemented." This is misleading: if the output shape is not 1, then we're not reducing to a scalar, so the error message doesn't match the failure condition.

Consider a clearer message like:

ICHECK(is_one(dst_layout->OutputShape().back()))
    << "Expected scalar output (extent=1) for 1D reduction, got extent="
    << dst_layout->OutputShape().back();
🤖 Prompt for AI Agents
In src/op/reduce.cc around lines 198 to 199, the ICHECK asserts that
dst_layout->OutputShape().back() is 1 but the error message "Reduce for scalar
not implemented." is misleading; change the message to explicitly state the
expected extent and the actual value so it reflects the failure condition (e.g.,
"Expected scalar output (extent=1) for 1D reduction, got extent=<value>").
Update the ICHECK message to include dst_layout->OutputShape().back() so the log
shows the actual extent when the check fails.

Comment on lines +323 to +339
if (need_duplicate) {
PrimExpr src_val = BufferLoad(clear_buffer, dst_indices);
PrimExpr dst_val = BufferLoad(dst_buffer, dst_indices);
PrimExpr update;
if (this->type->isSum() || this->type->isAbsSum()) {
update = dst_val + src_val;
} else if (this->type->isBitAnd()) {
update = this->clear ? src_val : bitwise_and(dst_val, src_val);
} else if (this->type->isBitOr()) {
update = bitwise_or(dst_val, src_val);
} else if (this->type->isBitXor()) {
update = bitwise_xor(dst_val, src_val);
} else {
stmts.push_back(BufferStore(
dst_buffer, BufferLoad(clear_buffer, dst_indices), dst_indices));
LOG(FATAL) << "Unsupported reduce type: " << this->type->type;
}
} else if (this->type->isBitOr()) {
stmts.push_back(
BufferStore(dst_buffer,
bitwise_or(BufferLoad(dst_buffer, dst_indices),
BufferLoad(clear_buffer, dst_indices)),
dst_indices));
} else if (this->type->isBitXor()) {
stmts.push_back(
BufferStore(dst_buffer,
bitwise_xor(BufferLoad(dst_buffer, dst_indices),
BufferLoad(clear_buffer, dst_indices)),
dst_indices));
stmts.push_back(BufferStore(dst_buffer, update, dst_indices));
}

⚠️ Potential issue | 🟡 Minor

Remove dead code in BitAnd ternary check.

At line 330, the code checks this->clear ? src_val : bitwise_and(dst_val, src_val). However, this entire block (lines 323-339) only executes when need_duplicate is true. For BitAnd, need_duplicate is set to true only when !this->clear (lines 237-238). Therefore, when reaching line 330, this->clear is always false, making the true branch of the ternary dead code.

Apply this diff to simplify:

     } else if (this->type->isBitAnd()) {
-      update = this->clear ? src_val : bitwise_and(dst_val, src_val);
+      update = bitwise_and(dst_val, src_val);
     } else if (this->type->isBitOr()) {
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if (need_duplicate) {
PrimExpr src_val = BufferLoad(clear_buffer, dst_indices);
PrimExpr dst_val = BufferLoad(dst_buffer, dst_indices);
PrimExpr update;
if (this->type->isSum() || this->type->isAbsSum()) {
update = dst_val + src_val;
} else if (this->type->isBitAnd()) {
update = this->clear ? src_val : bitwise_and(dst_val, src_val);
} else if (this->type->isBitOr()) {
update = bitwise_or(dst_val, src_val);
} else if (this->type->isBitXor()) {
update = bitwise_xor(dst_val, src_val);
} else {
stmts.push_back(BufferStore(
dst_buffer, BufferLoad(clear_buffer, dst_indices), dst_indices));
LOG(FATAL) << "Unsupported reduce type: " << this->type->type;
}
} else if (this->type->isBitOr()) {
stmts.push_back(
BufferStore(dst_buffer,
bitwise_or(BufferLoad(dst_buffer, dst_indices),
BufferLoad(clear_buffer, dst_indices)),
dst_indices));
} else if (this->type->isBitXor()) {
stmts.push_back(
BufferStore(dst_buffer,
bitwise_xor(BufferLoad(dst_buffer, dst_indices),
BufferLoad(clear_buffer, dst_indices)),
dst_indices));
stmts.push_back(BufferStore(dst_buffer, update, dst_indices));
}
if (need_duplicate) {
PrimExpr src_val = BufferLoad(clear_buffer, dst_indices);
PrimExpr dst_val = BufferLoad(dst_buffer, dst_indices);
PrimExpr update;
if (this->type->isSum() || this->type->isAbsSum()) {
update = dst_val + src_val;
} else if (this->type->isBitAnd()) {
update = bitwise_and(dst_val, src_val);
} else if (this->type->isBitOr()) {
update = bitwise_or(dst_val, src_val);
} else if (this->type->isBitXor()) {
update = bitwise_xor(dst_val, src_val);
} else {
LOG(FATAL) << "Unsupported reduce type: " << this->type->type;
}
stmts.push_back(BufferStore(dst_buffer, update, dst_indices));
}
🤖 Prompt for AI Agents
In src/op/reduce.cc around lines 323 to 339, the BitAnd branch uses a ternary
this->clear ? src_val : bitwise_and(dst_val, src_val) but need_duplicate is only
true when !this->clear, so the true branch is dead; simplify by replacing the
ternary with the direct expression bitwise_and(dst_val, src_val) for the BitAnd
case (remove the unreachable branch and keep the update assignment using
bitwise_and).

Comment on lines +366 to 424
if (is_shared_scope(src_scope) && is_shared_scope(dst_scope)) {
Buffer src_buffer = get_buffer(this->src);
Buffer dst_buffer = get_buffer(this->dst);

size_t src_dim = src_buffer->shape.size();
size_t dst_dim = dst_buffer->shape.size();
bool is_1d_reduce = (src_dim == dst_dim && dst_dim == 1);
if (!is_1d_reduce) {
ICHECK_EQ(src_dim, dst_dim + 1) << "Reduce dimension mismatch.";
} else {
ICHECK_EQ(dst_dim, 1U) << "Expect scalar layout for 1D reduce.";
}

auto thread_extent = as_const_int(T.thread_bounds->extent);
ICHECK(thread_extent)
<< "Shared-memory reduce requires static thread extent.";
int threads = *thread_extent;

if (TargetIsCuda(T.target)) {
ICHECK_EQ(threads % 32, 0)
<< "Shared reduce expects blockDim.x to be a multiple of 32 on CUDA.";
} else if (TargetIsRocm(T.target)) {
ICHECK_EQ(threads % 64, 0)
<< "Shared reduce expects blockDim.x to be a multiple of 64 on HIP.";
}

bool use_abs = this->type->isAbsSum() || this->type->isAbsMax();
bool need_accumulate =
(!this->clear) && (this->type->isSum() || this->type->isAbsSum() ||
this->type->isBitAnd() || this->type->isBitOr() ||
this->type->isBitXor());

PrimExpr reduce_extent = src_buffer->shape[this->dim];
PrimExpr tail_extent = make_const(DataType::Int(32), 1);
for (size_t i = this->dim + 1; i < src_dim; ++i) {
tail_extent = analyzer->Simplify(tail_extent * src_buffer->shape[i]);
}

PrimExpr total_dest = make_const(DataType::Int(32), 1);
for (size_t i = 0; i < dst_dim; ++i) {
total_dest = analyzer->Simplify(total_dest * dst_buffer->shape[i]);
}

std::stringstream ss;
std::string reducer = this->MakeCodegenReducer();
ss << "tl::SharedReduceWarp<" << reducer << ", " << threads << ", "
<< (use_abs ? "true" : "false") << ", "
<< (need_accumulate ? "true" : "false") << ">::run";

Array<PrimExpr> call_args = {StringImm(ss.str()),
src_buffer.access_ptr(1),
dst_buffer.access_ptr(3),
cast(DataType::Int(32), total_dest),
cast(DataType::Int(32), reduce_extent),
cast(DataType::Int(32), tail_extent),
this->MakeInitValue()};

return Evaluate(Call(dst_buffer->dtype, builtin::call_extern(), call_args));
}

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Add target validation for shared-memory reduction path.

The shared-memory reduction path generates a tl::SharedReduceWarp call that is only available for CUDA and HIP targets (per the PR description and template files). However, lines 384-390 only validate thread extents for CUDA/HIP without ensuring the target is one of these platforms. If the target is neither CUDA nor HIP, the code will still generate the call, potentially causing compilation failures.

Add target validation at the beginning of the shared-memory path:

 if (is_shared_scope(src_scope) && is_shared_scope(dst_scope)) {
+  ICHECK(TargetIsCuda(T.target) || TargetIsRocm(T.target))
+      << "Shared-memory reduction is only supported on CUDA and HIP targets, got: "
+      << T.target;
+
   Buffer src_buffer = get_buffer(this->src);

Verify: Does shared-memory reduction require loop wrapping?

The shared-memory reduction path returns a bare Evaluate(Call(...)) statement (line 423) without any loop wrapping, thread guards, or partitioning, unlike the local.fragment path (lines 341-359). This suggests the SharedReduceWarp template internally handles all thread coordination and iteration.

Please confirm this is the intended behavior. If the template handles everything internally, consider adding a comment explaining why no loop wrapping is needed, for example:

// SharedReduceWarp template handles all thread coordination and iteration internally
return Evaluate(Call(dst_buffer->dtype, builtin::call_extern(), call_args));

I'll verify the review comment by examining the codebase. Let me search for the specific code and patterns mentioned.


🏁 Script executed:

#!/bin/bash
# Verify the code snippet and check target validation patterns

# 1. Check if src/op/reduce.cc exists and view the specific lines
echo "=== Checking src/op/reduce.cc lines 366-424 ==="
if [ -f "src/op/reduce.cc" ]; then
  sed -n '366,424p' src/op/reduce.cc
else
  echo "File not found, searching for reduce.cc..."
  fd -t f -i "reduce\.cc"
fi

Length of output: 2585


🏁 Script executed:

#!/bin/bash
# Look for TargetIsCuda and TargetIsRocm usage patterns

echo "=== Searching for TargetIsCuda/TargetIsRocm usage patterns ==="
rg -n "TargetIsCuda|TargetIsRocm" -C 5 --type cc --type cpp | head -60

Length of output: 89


🏁 Script executed:

#!/bin/bash
# Check for SharedReduceWarp template references

echo "=== Searching for SharedReduceWarp references ==="
rg -rn "SharedReduceWarp" --type cc --type cpp --type h

Length of output: 76


🏁 Script executed:

#!/bin/bash
# Look at the local.fragment path mentioned (lines 341-359) for comparison

echo "=== Checking local.fragment reduction path (lines 341-359) ==="
if [ -f "src/op/reduce.cc" ]; then
  sed -n '341,359p' src/op/reduce.cc
fi

Length of output: 764


🏁 Script executed:

#!/bin/bash
# Check for other target validation patterns in reduce.cc

echo "=== Searching for target validation patterns in reduce.cc ==="
rg -n "TargetIs" src/op/reduce.cc -B 2 -A 2

Length of output: 1073


I need to check the broader context to understand if there's a guard before this shared-memory path, and verify target-specific patterns.


🏁 Script executed:

#!/bin/bash
# Check broader context around the shared-memory reduction (lines 350-430)

echo "=== Full context around shared-memory reduction path ==="
sed -n '320,430p' src/op/reduce.cc

Length of output: 4312


🏁 Script executed:

#!/bin/bash
# Search for SharedReduceWarp references without type flags

echo "=== Searching for SharedReduceWarp ==="
rg -rn "SharedReduceWarp"

Length of output: 380


🏁 Script executed:

#!/bin/bash
# Check if there's a guard before the shared-memory path

echo "=== Looking for conditional guards before line 366 ==="
sed -n '330,370p' src/op/reduce.cc

Length of output: 1508


🏁 Script executed:

#!/bin/bash
# Search for TargetIsCuda/TargetIsRocm patterns in reduce.cc without type flags

echo "=== All TargetIsCuda/TargetIsRocm usage in reduce.cc ==="
rg -n "TargetIsCuda\|TargetIsRocm" src/op/reduce.cc -B 3 -A 3

Length of output: 56


🏁 Script executed:

#!/bin/bash
# Look for where this reduction code is called from (find method signature)

echo "=== Finding method containing line 366 ==="
sed -n '280,370p' src/op/reduce.cc | grep -n "if\|for\|void\|Stmt\|PrimExpr"

Length of output: 1270


I need to examine the template implementations to verify whether loop wrapping is necessary.


🏁 Script executed:

#!/bin/bash
# View CUDA reduce template implementation

echo "=== CUDA reduce.h template ==="
head -100 src/tl_templates/cuda/reduce.h

Length of output: 3101


🏁 Script executed:

#!/bin/bash
# View HIP reduce template implementation

echo "=== HIP reduce.h template ==="
head -100 src/tl_templates/hip/reduce.h

Length of output: 2994


🏁 Script executed:

#!/bin/bash
# Search for "run" function/method in templates to understand call signature

echo "=== Searching for 'run' method in templates ==="
rg -n "run" src/tl_templates/cuda/reduce.h src/tl_templates/hip/reduce.h -A 5

Length of output: 4315


🏁 Script executed:

#!/bin/bash
# Check full method signature containing the shared-memory path

echo "=== Full method containing shared-memory reduction ==="
sed -n '100,200p' src/op/reduce.cc | head -50

Length of output: 1615


🏁 Script executed:

#!/bin/bash
# Search for method signature/function name

echo "=== Looking for Stmt or PrimExpr return methods around line 366 ==="
sed -n '100,370p' src/op/reduce.cc | grep -n "Stmt\|PrimExpr" | tail -20

Length of output: 1165


Add target validation guard at the beginning of the shared-memory reduction path.

The shared-memory reduction path generates a tl::SharedReduceWarp template call that is only available for CUDA and HIP targets. However, the code (lines 384-389) checks thread extents for CUDA/HIP without guarding against other targets. If the target is neither CUDA nor HIP, the code will still generate the template call, causing compilation to fail since templates only exist in src/tl_templates/cuda/reduce.h and src/tl_templates/hip/reduce.h.

 if (is_shared_scope(src_scope) && is_shared_scope(dst_scope)) {
+  ICHECK(TargetIsCuda(T.target) || TargetIsRocm(T.target))
+      << "Shared-memory reduction is only supported on CUDA and HIP targets, got: "
+      << T.target;
+
   Buffer src_buffer = get_buffer(this->src);

Regarding loop wrapping: The SharedReduceWarp template handles all thread coordination and iteration internally via device-side for-loops, so the bare Evaluate(Call(...)) (line 423) is correct. The local.fragment path differs only because it emits explicit TIR loops that are lowered together with the rest of the kernel, whereas the shared-memory path calls an extern device-side template that performs its own iteration.

@LeiWang1999 (Member, Author) commented:

@codex review

@chatgpt-codex-connector

Codex Review: Didn't find any major issues. 🚀

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
