[Language] Efficient T.reduce_ with shared memory input/output
#1080
Conversation
Walkthrough

Refactors ReduceOp lowering to resolve buffer remaps and branch by buffer scope; adds warp-level SharedReduceWarp templates for CUDA/HIP (with abs/accumulate and bitwise ops); consolidates and expands reduction tests into a unified suite; and tweaks a CUDA error constant in wrapper.py.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Lower as ReduceOp Lowering
    participant Remap as Buffer Remap Lookup
    participant Scope as Scope Resolver
    participant Local as local.fragment Path
    participant Shared as shared / shared.dyn Path
    participant Warp as SharedReduceWarp (CUDA/HIP)
    participant Runtime as Runtime / Logs
    Lower->>Remap: get_buffer(src/dst)
    Remap-->>Scope: return buffer + scope
    Scope-->>Lower: src_scope, dst_scope
    alt both local.fragment
        Lower->>Local: validate dims, insert reduce dim
        Local->>Local: allocate/clear dup buffer if needed
        Local->>Local: per-dimension loops & reduction pipeline
        Local->>Lower: assemble final write-back (sum/abs/bitwise)
    else shared/shared.dyn
        Lower->>Shared: validate dims & platform thread extents
        Shared->>Warp: build params (total_dest, reduce_extent, tail, init)
        Warp->>Warp: per-warp reduction (shuffle/accumulate)
        Warp-->>Shared: return Evaluate(Call(...))
        Shared->>Lower: integrate call into schedule
    else
        Lower->>Runtime: Fatal("unsupported scope/type")
    end
```
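The dispatch shown above boils down to a scope check at the top of the lowering routine. A minimal illustrative sketch (the helper names `LowerFragmentReduce`/`LowerSharedReduce` are hypothetical; the real `ReduceOpNode::Lower` carries more state, as the excerpts later in this review show):

```cpp
// Illustrative only: dispatch on the (possibly remapped) buffer scopes.
Stmt LowerReduce(const Buffer &src, const Buffer &dst) {
  auto is_shared = [](const std::string &s) {
    return s == "shared" || s == "shared.dyn";
  };
  std::string src_scope = src.scope(), dst_scope = dst.scope();
  if (src_scope == "local.fragment" && dst_scope == "local.fragment") {
    return LowerFragmentReduce(src, dst);  // per-thread loops + write-back
  } else if (is_shared(src_scope) && is_shared(dst_scope)) {
    return LowerSharedReduce(src, dst);    // extern call to tl::SharedReduceWarp
  }
  LOG(FATAL) << "Unsupported scope combination: " << src_scope << " -> "
             << dst_scope;
  return Stmt();
}
```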
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
tilelang/jit/adapter/wrapper.py (2)
500-504: Bug: l2_persistent_map stored under stale function_name.

Inside this loop, function_name refers to the last device func from a previous loop, so you end up assigning the entire L2 map to the wrong key. Load the whole mapping once instead.

Apply this diff:

```diff
-        if "l2_persistent_map" in func.attrs:
-            self.l2_persistent_map[function_name] = func.attrs["l2_persistent_map"]
+        if "l2_persistent_map" in func.attrs:
+            # Attr is a per-kernel map: { kernel_name: {buffer_name: (hit_ratio, bytes), ...}, ... }
+            self.l2_persistent_map = dict(func.attrs["l2_persistent_map"])
```
1129-1133: Undefined attribute in error path.

self.arch.platform is not defined on TLWrapper/TLPyWrapper. Use self.target to report the unsupported target.

Apply this diff (and the same change in TLPyWrapper.wrap):

```diff
-        else:
-            raise ValueError(f"Unsupported platform: {self.arch.platform}")
+        else:
+            raise ValueError(f"Unsupported target: {self.target}")
```

Also applies to: 1151-1154
src/op/reduce.cc (1)
56-68: Fix overflow in init constants for int/uint min/max.

1 << ... overflows for 32/64-bit widths (UB). Use 64-bit shifts, then cast.

Apply this diff:

```diff
-    if (is_int) {
-      return make_const(dst->dtype, -(1 << (bits - 1)));
+    if (is_int) {
+      return make_const(dst->dtype, -(int64_t(1) << (bits - 1)));
     } else if (is_uint) {
       return make_const(dst->dtype, 0);
     } else {
       return make_const(dst->dtype, -INFINITY);
     }
@@
-    if (is_int) {
-      return make_const(dst->dtype, (1 << (bits - 1)) - 1);
+    if (is_int) {
+      return make_const(dst->dtype, (int64_t(1) << (bits - 1)) - 1);
     } else if (is_uint) {
-      return make_const(dst->dtype, (1 << bits) - 1);
+      return make_const(dst->dtype, (uint64_t(1) << bits) - 1);
     } else {
       return make_const(dst->dtype, INFINITY);
     }
```

Also applies to: 64-71, 74-81
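As a standalone illustration of the overflow (plain C++, separate from TVM's make_const; the unsigned helper assumes bits < 64):

```cpp
#include <cassert>
#include <cstdint>

// Shifting a 32-bit int by 31 or more bits is undefined behavior; widen to
// 64 bits first, then let make_const (or a cast) narrow back to the dtype.
int64_t int_min_of(int bits)   { return -(int64_t(1) << (bits - 1)); }
int64_t int_max_of(int bits)   { return (int64_t(1) << (bits - 1)) - 1; }
uint64_t uint_max_of(int bits) { return (uint64_t(1) << bits) - 1; }  // bits < 64

int main() {
  assert(int_min_of(32) == INT32_MIN);    // -(1 << 31) on plain int would be UB
  assert(int_max_of(32) == INT32_MAX);
  assert(uint_max_of(32) == UINT32_MAX);  // (1 << 32) - 1 on plain int would be UB
  return 0;
}
```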
🧹 Nitpick comments (5)
tilelang/jit/adapter/wrapper.py (1)
732-763: NVRTC: device index detection may default to 0.

If no ctypes.c_void_p args are present, device_index remains 0, which can set attributes on the wrong device in multi-GPU runs. Consider falling back to the current device from the runtime or annotating kernels with device id.

Would you like a small patch to fetch the current CUDA device (via your cuda.bindings.runtime shim) when no buffer args are found?

src/tl_templates/cuda/reduce.h (1)
74-77: Specify shuffle width to match kWarpSize.

Pass kWarpSize to __shfl_down_sync for clarity and safety.

Apply this diff:

```diff
-    for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) {
-      T other = __shfl_down_sync(mask, partial, offset);
+    for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) {
+      T other = __shfl_down_sync(mask, partial, offset, kWarpSize);
```
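For context, the full warp-sum pattern with the explicit width looks like this (a sketch, not the template's exact code; kWarpSize is 32 on CUDA):

```cpp
// Tree reduction over a warp: after log2(kWarpSize) shuffle steps,
// lane 0 holds the sum of all participating lanes.
template <typename T, int kWarpSize = 32>
__device__ T WarpSum(T partial) {
  unsigned mask = __activemask();
  #pragma unroll
  for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) {
    // The fourth argument pins the shuffle width to the warp size.
    partial += __shfl_down_sync(mask, partial, offset, kWarpSize);
  }
  return partial;
}
```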
src/op/reduce.cc (2)

197-203: Nit: misleading scalar reduce message.

The check enforces scalar output; the message says "not implemented." Consider clarifying.

Apply this diff:

```diff
-  ICHECK(is_one(dst_layout->OutputShape().back()))
-      << "Reduce for scalar not implemented.";
+  ICHECK(is_one(dst_layout->OutputShape().back()))
+      << "Expect scalar output layout (last extent = 1) for 1D reduce.";
```
307-314: Workspace sizing for AllReduce (minor).

You always allocate workspace sized by T.thread_bounds->extent, but only need it when reducing_threads >= 32. Consider sizing by reducing_threads to save shared/local memory.

testing/python/language/test_tilelang_language_reduce.py (1)
133-143: Add bitwise reduction tests for parity and regressions.

Since BitAnd/BitOr/BitXor were added, include shared-memory tests verifying:

- and/or/xor over dim=1 with clear=True/False
- integer dtypes (e.g., int32, uint32)

I can draft minimal builders mirroring reduce_*_ss.

Also applies to: 145-163, 186-204, 206-224
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
- src/op/reduce.cc (1 hunks)
- src/tl_templates/cuda/reduce.h (1 hunks)
- src/tl_templates/hip/reduce.h (1 hunks)
- testing/python/language/test_tilelang_language_reduce.py (1 hunks)
- testing/python/language/test_tilelang_language_reduce_max.py (0 hunks)
- testing/python/language/test_tilelang_language_reduce_sum.py (0 hunks)
- tilelang/jit/adapter/wrapper.py (1 hunks)
💤 Files with no reviewable changes (2)
- testing/python/language/test_tilelang_language_reduce_sum.py
- testing/python/language/test_tilelang_language_reduce_max.py
🧰 Additional context used
🧬 Code graph analysis (4)
src/tl_templates/cuda/reduce.h (1)
- src/tl_templates/hip/reduce.h (1): void (46-87)

src/tl_templates/hip/reduce.h (1)
- src/tl_templates/cuda/reduce.h (3): T (208-280), void (46-87), run (140-201)

src/op/reduce.cc (3)
- src/layout/utils.cc (10): CompressIterator (200-244), CompressIterator (200-203), iter_sum (174-186), iter_sum (174-174), ToVMap (254-260), ToVMap (254-254), mark (57-67), mark (57-57), mark (188-194), mark (188-188)
- src/target/utils.cc (8): TargetIsHopper (49-54), TargetIsHopper (49-49), TargetIsSm100 (56-61), TargetIsSm100 (56-56), TargetIsCuda (11-13), TargetIsCuda (11-11), TargetIsRocm (14-16), TargetIsRocm (14-14)
- src/transform/loop_partition.cc (2): PartitionLoop (61-163), PartitionLoop (61-62)

testing/python/language/test_tilelang_language_reduce.py (10)
- src/tl_templates/cuda/reduce.h (2): tl (5-88), T (208-280)
- src/tl_templates/hip/reduce.h (1): tl (5-88)
- tilelang/testing/__init__.py (1): set_random_seed (30-35)
- tilelang/language/allocate.py (2): alloc_shared (21-36), alloc_fragment (53-64)
- tilelang/language/copy.py (1): copy (10-86)
- tilelang/jit/__init__.py (1): compile (34-90)
- tilelang/jit/kernel.py (2): out_idx (462-463), get_profiler (376-392)
- tilelang/profiler/__init__.py (1): assert_allclose (76-145)
- tilelang/language/reduce.py (5): reduce_max (50-68), reduce_sum (87-109), reduce_min (71-84), reduce_abssum (112-124), reduce_absmax (127-139)
- tilelang/language/fill.py (2): fill (9-21), clear (24-48)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Quick Lint
🔇 Additional comments (3)
tilelang/jit/adapter/wrapper.py (1)
13-19: Correct CUDA runtime error check.

Switching to cudaSuccess matches cudaError_t / CUDA runtime usage. Good fix.

src/tl_templates/hip/reduce.h (1)
25-41: HIP reduction primitives look good.

Wave-size handling, shuffle width, and bitwise functors align with CUDA parity.

Also applies to: 43-88, 90-115
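A minimal HIP analogue of the CUDA pattern, assuming wave size 64 and HIP's unsynchronized `__shfl_down` (which takes no mask argument):

```cpp
// HIP wave-level sum: the same tree reduction, but over a 64-lane
// wavefront and without a participation mask.
template <typename T, int kWaveSize = 64>
__device__ T WaveSum(T partial) {
  #pragma unroll
  for (int offset = kWaveSize / 2; offset > 0; offset >>= 1) {
    partial += __shfl_down(partial, offset, kWaveSize);
  }
  return partial;
}
```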
src/op/reduce.cc (1)
387-392: Shared-memory reduce lowering looks consistent.
- abs handling uses the UseAbs flag
- accumulate semantics map to NeedAccumulate
- dst is passed with an RW access_ptr
LGTM.
Also applies to: 404-418
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/op/reduce.cc (1)
141-176: Update outdated documentation.

The comment block describes the old implementation that only supported local.fragment scope. With the addition of shared-memory reduction support (lines 366-424), this documentation is now outdated and misleading. Update it to reflect both reduction paths and their respective behaviors.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- src/op/reduce.cc (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/op/reduce.cc (3)
- src/layout/utils.cc (10): CompressIterator (200-244), CompressIterator (200-203), iter_sum (174-186), iter_sum (174-174), ToVMap (254-260), ToVMap (254-254), mark (57-67), mark (57-57), mark (188-194), mark (188-188)
- src/target/utils.cc (8): TargetIsHopper (49-54), TargetIsHopper (49-49), TargetIsSm100 (56-61), TargetIsSm100 (56-56), TargetIsCuda (11-13), TargetIsCuda (11-11), TargetIsRocm (14-16), TargetIsRocm (14-14)
- src/transform/loop_partition.cc (2): PartitionLoop (61-163), PartitionLoop (61-62)
🔇 Additional comments (4)
src/op/reduce.cc (4)
178-185: LGTM!

The get_buffer helper cleanly handles buffer remapping, and extracting scopes upfront enables clear branching logic between local.fragment and shared-memory paths.
341-359: LGTM!

The final assembly correctly wraps the reduction body in parallel loops, applies loop partitioning for thread-level parallelism, and handles the scalar case with an appropriate thread guard. The conditional allocation of the clear buffer is also properly placed.
398-407: LGTM!

The extent calculations correctly compute the reduction dimension size, the tail extent for dimensions beyond the reduction axis (defaulting to 1 when reducing the last dimension), and the total destination size. These values properly parameterize the SharedReduceWarp call.
426-428: LGTM!

The final error handler provides a clear, informative message for unsupported scope combinations, making debugging easier when an unexpected case is encountered.
```cpp
ICHECK(is_one(dst_layout->OutputShape().back()))
    << "Reduce for scalar not implemented.";
```
Clarify confusing error message.
The assertion checks that dst_layout->OutputShape().back() is 1, but the error message states "Reduce for scalar not implemented." This is misleading: if the output shape is not 1, then we're not reducing to a scalar, so the error message doesn't match the failure condition.
Consider a clearer message like:
```cpp
ICHECK(is_one(dst_layout->OutputShape().back()))
    << "Expected scalar output (extent=1) for 1D reduction, got extent="
    << dst_layout->OutputShape().back();
```

🤖 Prompt for AI Agents
<< dst_layout->OutputShape().back();🤖 Prompt for AI Agents
In src/op/reduce.cc around lines 198 to 199, the ICHECK asserts that
dst_layout->OutputShape().back() is 1 but the error message "Reduce for scalar
not implemented." is misleading; change the message to explicitly state the
expected extent and the actual value so it reflects the failure condition (e.g.,
"Expected scalar output (extent=1) for 1D reduction, got extent=<value>").
Update the ICHECK message to include dst_layout->OutputShape().back() so the log
shows the actual extent when the check fails.
```cpp
if (need_duplicate) {
  PrimExpr src_val = BufferLoad(clear_buffer, dst_indices);
  PrimExpr dst_val = BufferLoad(dst_buffer, dst_indices);
  PrimExpr update;
  if (this->type->isSum() || this->type->isAbsSum()) {
    update = dst_val + src_val;
  } else if (this->type->isBitAnd()) {
    update = this->clear ? src_val : bitwise_and(dst_val, src_val);
  } else if (this->type->isBitOr()) {
    update = bitwise_or(dst_val, src_val);
  } else if (this->type->isBitXor()) {
    update = bitwise_xor(dst_val, src_val);
  } else {
    stmts.push_back(BufferStore(
        dst_buffer, BufferLoad(clear_buffer, dst_indices), dst_indices));
    LOG(FATAL) << "Unsupported reduce type: " << this->type->type;
  }
} else if (this->type->isBitOr()) {
  stmts.push_back(
      BufferStore(dst_buffer,
                  bitwise_or(BufferLoad(dst_buffer, dst_indices),
                             BufferLoad(clear_buffer, dst_indices)),
                  dst_indices));
} else if (this->type->isBitXor()) {
  stmts.push_back(
      BufferStore(dst_buffer,
                  bitwise_xor(BufferLoad(dst_buffer, dst_indices),
                              BufferLoad(clear_buffer, dst_indices)),
                  dst_indices));
  stmts.push_back(BufferStore(dst_buffer, update, dst_indices));
}
```
Remove dead code in BitAnd ternary check.
At line 330, the code checks this->clear ? src_val : bitwise_and(dst_val, src_val). However, this entire block (lines 323-339) only executes when need_duplicate is true. For BitAnd, need_duplicate is set to true only when !this->clear (lines 237-238). Therefore, when reaching line 330, this->clear is always false, making the true branch of the ternary dead code.
Apply this diff to simplify:
```diff
   } else if (this->type->isBitAnd()) {
-    update = this->clear ? src_val : bitwise_and(dst_val, src_val);
+    update = bitwise_and(dst_val, src_val);
   } else if (this->type->isBitOr()) {
```

📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```cpp
if (need_duplicate) {
  PrimExpr src_val = BufferLoad(clear_buffer, dst_indices);
  PrimExpr dst_val = BufferLoad(dst_buffer, dst_indices);
  PrimExpr update;
  if (this->type->isSum() || this->type->isAbsSum()) {
    update = dst_val + src_val;
  } else if (this->type->isBitAnd()) {
    update = bitwise_and(dst_val, src_val);
  } else if (this->type->isBitOr()) {
    update = bitwise_or(dst_val, src_val);
  } else if (this->type->isBitXor()) {
    update = bitwise_xor(dst_val, src_val);
  } else {
    LOG(FATAL) << "Unsupported reduce type: " << this->type->type;
  }
  stmts.push_back(BufferStore(dst_buffer, update, dst_indices));
}
```
🤖 Prompt for AI Agents

```
In src/op/reduce.cc around lines 323 to 339, the BitAnd branch uses a ternary
this->clear ? src_val : bitwise_and(dst_val, src_val) but need_duplicate is only
true when !this->clear, so the true branch is dead; simplify by replacing the
ternary with the direct expression bitwise_and(dst_val, src_val) for the BitAnd
case (remove the unreachable branch and keep the update assignment using
bitwise_and).
```
```cpp
if (is_shared_scope(src_scope) && is_shared_scope(dst_scope)) {
  Buffer src_buffer = get_buffer(this->src);
  Buffer dst_buffer = get_buffer(this->dst);

  size_t src_dim = src_buffer->shape.size();
  size_t dst_dim = dst_buffer->shape.size();
  bool is_1d_reduce = (src_dim == dst_dim && dst_dim == 1);
  if (!is_1d_reduce) {
    ICHECK_EQ(src_dim, dst_dim + 1) << "Reduce dimension mismatch.";
  } else {
    ICHECK_EQ(dst_dim, 1U) << "Expect scalar layout for 1D reduce.";
  }

  auto thread_extent = as_const_int(T.thread_bounds->extent);
  ICHECK(thread_extent)
      << "Shared-memory reduce requires static thread extent.";
  int threads = *thread_extent;

  if (TargetIsCuda(T.target)) {
    ICHECK_EQ(threads % 32, 0)
        << "Shared reduce expects blockDim.x to be a multiple of 32 on CUDA.";
  } else if (TargetIsRocm(T.target)) {
    ICHECK_EQ(threads % 64, 0)
        << "Shared reduce expects blockDim.x to be a multiple of 64 on HIP.";
  }

  bool use_abs = this->type->isAbsSum() || this->type->isAbsMax();
  bool need_accumulate =
      (!this->clear) && (this->type->isSum() || this->type->isAbsSum() ||
                         this->type->isBitAnd() || this->type->isBitOr() ||
                         this->type->isBitXor());

  PrimExpr reduce_extent = src_buffer->shape[this->dim];
  PrimExpr tail_extent = make_const(DataType::Int(32), 1);
  for (size_t i = this->dim + 1; i < src_dim; ++i) {
    tail_extent = analyzer->Simplify(tail_extent * src_buffer->shape[i]);
  }

  PrimExpr total_dest = make_const(DataType::Int(32), 1);
  for (size_t i = 0; i < dst_dim; ++i) {
    total_dest = analyzer->Simplify(total_dest * dst_buffer->shape[i]);
  }

  std::stringstream ss;
  std::string reducer = this->MakeCodegenReducer();
  ss << "tl::SharedReduceWarp<" << reducer << ", " << threads << ", "
     << (use_abs ? "true" : "false") << ", "
     << (need_accumulate ? "true" : "false") << ">::run";

  Array<PrimExpr> call_args = {StringImm(ss.str()),
                               src_buffer.access_ptr(1),
                               dst_buffer.access_ptr(3),
                               cast(DataType::Int(32), total_dest),
                               cast(DataType::Int(32), reduce_extent),
                               cast(DataType::Int(32), tail_extent),
                               this->MakeInitValue()};

  return Evaluate(Call(dst_buffer->dtype, builtin::call_extern(), call_args));
}
```
🧩 Analysis chain
Add target validation for shared-memory reduction path.
The shared-memory reduction path generates a tl::SharedReduceWarp call that is only available for CUDA and HIP targets (per the PR description and template files). However, lines 384-390 only validate thread extents for CUDA/HIP without ensuring the target is one of these platforms. If the target is neither CUDA nor HIP, the code will still generate the call, potentially causing compilation failures.
Add target validation at the beginning of the shared-memory path:

```diff
 if (is_shared_scope(src_scope) && is_shared_scope(dst_scope)) {
+  ICHECK(TargetIsCuda(T.target) || TargetIsRocm(T.target))
+      << "Shared-memory reduction is only supported on CUDA and HIP targets, got: "
+      << T.target;
+
   Buffer src_buffer = get_buffer(this->src);
```

Verify: does shared-memory reduction require loop wrapping?
The shared-memory reduction path returns a bare Evaluate(Call(...)) statement (line 423) without any loop wrapping, thread guards, or partitioning, unlike the local.fragment path (lines 341-359). This suggests the SharedReduceWarp template internally handles all thread coordination and iteration.
Please confirm this is the intended behavior. If the template handles everything internally, consider adding a comment explaining why no loop wrapping is needed, for example:

```cpp
// SharedReduceWarp template handles all thread coordination and iteration internally
return Evaluate(Call(dst_buffer->dtype, builtin::call_extern(), call_args));
```

I'll verify the review comment by examining the codebase. Let me search for the specific code and patterns mentioned.
🏁 Script executed:

```shell
#!/bin/bash
# Verify the code snippet and check target validation patterns
# 1. Check if src/op/reduce.cc exists and view the specific lines
echo "=== Checking src/op/reduce.cc lines 366-424 ==="
if [ -f "src/op/reduce.cc" ]; then
  sed -n '366,424p' src/op/reduce.cc
else
  echo "File not found, searching for reduce.cc..."
  fd -t f -i "reduce\.cc"
fi
```

Length of output: 2585

🏁 Script executed:

```shell
#!/bin/bash
# Look for TargetIsCuda and TargetIsRocm usage patterns
echo "=== Searching for TargetIsCuda/TargetIsRocm usage patterns ==="
rg -n "TargetIsCuda|TargetIsRocm" -C 5 --type cc --type cpp | head -60
```

Length of output: 89

🏁 Script executed:

```shell
#!/bin/bash
# Check for SharedReduceWarp template references
echo "=== Searching for SharedReduceWarp references ==="
rg -rn "SharedReduceWarp" --type cc --type cpp --type h
```

Length of output: 76

🏁 Script executed:

```shell
#!/bin/bash
# Look at the local.fragment path mentioned (lines 341-359) for comparison
echo "=== Checking local.fragment reduction path (lines 341-359) ==="
if [ -f "src/op/reduce.cc" ]; then
  sed -n '341,359p' src/op/reduce.cc
fi
```

Length of output: 764

🏁 Script executed:

```shell
#!/bin/bash
# Check for other target validation patterns in reduce.cc
echo "=== Searching for target validation patterns in reduce.cc ==="
rg -n "TargetIs" src/op/reduce.cc -B 2 -A 2
```

Length of output: 1073
I need to check the broader context to understand if there's a guard before this shared-memory path, and verify target-specific patterns.
🏁 Script executed:

```shell
#!/bin/bash
# Check broader context around the shared-memory reduction (lines 350-430)
echo "=== Full context around shared-memory reduction path ==="
sed -n '320,430p' src/op/reduce.cc
```

Length of output: 4312

🏁 Script executed:

```shell
#!/bin/bash
# Search for SharedReduceWarp references without type flags
echo "=== Searching for SharedReduceWarp ==="
rg -rn "SharedReduceWarp"
```

Length of output: 380

🏁 Script executed:

```shell
#!/bin/bash
# Check if there's a guard before the shared-memory path
echo "=== Looking for conditional guards before line 366 ==="
sed -n '330,370p' src/op/reduce.cc
```

Length of output: 1508

🏁 Script executed:

```shell
#!/bin/bash
# Search for TargetIsCuda/TargetIsRocm patterns in reduce.cc without type flags
echo "=== All TargetIsCuda/TargetIsRocm usage in reduce.cc ==="
rg -n "TargetIsCuda\|TargetIsRocm" src/op/reduce.cc -B 3 -A 3
```

Length of output: 56

🏁 Script executed:

```shell
#!/bin/bash
# Look for where this reduction code is called from (find method signature)
echo "=== Finding method containing line 366 ==="
sed -n '280,370p' src/op/reduce.cc | grep -n "if\|for\|void\|Stmt\|PrimExpr"
```

Length of output: 1270
I need to examine the template implementations to verify whether loop wrapping is necessary.
🏁 Script executed:

```shell
#!/bin/bash
# View CUDA reduce template implementation
echo "=== CUDA reduce.h template ==="
head -100 src/tl_templates/cuda/reduce.h
```

Length of output: 3101

🏁 Script executed:

```shell
#!/bin/bash
# View HIP reduce template implementation
echo "=== HIP reduce.h template ==="
head -100 src/tl_templates/hip/reduce.h
```

Length of output: 2994

🏁 Script executed:

```shell
#!/bin/bash
# Search for "run" function/method in templates to understand call signature
echo "=== Searching for 'run' method in templates ==="
rg -n "run" src/tl_templates/cuda/reduce.h src/tl_templates/hip/reduce.h -A 5
```

Length of output: 4315

🏁 Script executed:

```shell
#!/bin/bash
# Check full method signature containing the shared-memory path
echo "=== Full method containing shared-memory reduction ==="
sed -n '100,200p' src/op/reduce.cc | head -50
```

Length of output: 1615

🏁 Script executed:

```shell
#!/bin/bash
# Search for method signature/function name
echo "=== Looking for Stmt or PrimExpr return methods around line 366 ==="
sed -n '100,370p' src/op/reduce.cc | grep -n "Stmt\|PrimExpr" | tail -20
```

Length of output: 1165
Add target validation guard at the beginning of the shared-memory reduction path.
The shared-memory reduction path generates a tl::SharedReduceWarp template call that is only available for CUDA and HIP targets. However, the code (lines 384-389) checks thread extents for CUDA/HIP without guarding against other targets. If the target is neither CUDA nor HIP, the code will still generate the template call, causing compilation to fail since templates only exist in src/tl_templates/cuda/reduce.h and src/tl_templates/hip/reduce.h.
```diff
 if (is_shared_scope(src_scope) && is_shared_scope(dst_scope)) {
+  ICHECK(TargetIsCuda(T.target) || TargetIsRocm(T.target))
+      << "Shared-memory reduction is only supported on CUDA and HIP targets, got: "
+      << T.target;
+
   Buffer src_buffer = get_buffer(this->src);
```

Regarding loop wrapping: the SharedReduceWarp template handles all thread coordination and iteration internally via device-side for-loops, so the bare Evaluate(Call(...)) (line 423) is correct. This differs from the local.fragment path only because that path operates at the TIR level with host-side code, whereas the shared-memory path invokes an external device kernel.
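To make that concrete, here is a hypothetical simplification of the iteration shape inside such a template: every thread of the block enters run(), each warp claims destination elements in a strided loop, and the reduce axis is consumed by lane-strided loads plus a shuffle tree, so the generated TIR needs no loop or thread guard around the call. The real tl::SharedReduceWarp differs in detail:

```cpp
// Hypothetical sketch only, not the shipped template.
template <class Reducer, int kThreads, bool kAbs, bool kAccum>
struct SharedReduceWarpSketch {
  template <typename T>
  static __device__ void run(const T *src, T *dst, int total_dest,
                             int reduce_extent, int tail, T init) {
    constexpr int kWarpSize = 32;
    constexpr int kWarps = kThreads / kWarpSize;
    const int warp = threadIdx.x / kWarpSize;
    const int lane = threadIdx.x % kWarpSize;
    for (int d = warp; d < total_dest; d += kWarps) {  // one dest elem per warp
      const int head = d / tail, off = d % tail;       // split index around reduce axis
      T acc = init;
      for (int r = lane; r < reduce_extent; r += kWarpSize) {
        T v = src[(head * reduce_extent + r) * tail + off];
        if (kAbs && v < T(0)) v = -v;
        acc = Reducer()(acc, v);
      }
      for (int o = kWarpSize / 2; o > 0; o >>= 1)      // shuffle tree
        acc = Reducer()(acc, __shfl_down_sync(0xffffffffu, acc, o, kWarpSize));
      if (lane == 0) dst[d] = kAccum ? Reducer()(dst[d], acc) : acc;
    }
    __syncthreads();
  }
};
```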
@codex review

Codex Review: Didn't find any major issues. 🚀
…e-ai#1080)

* Support reduce ss
* lint fix
* test fix
* lint fix
This pull request adds support for shared-memory reduction operations in both CUDA and HIP backends, refactors the reduction lowering logic to handle different memory scopes, and introduces comprehensive tests for shared-memory reductions. The main focus is on enabling and validating efficient reductions (sum, max, min, abs-sum, abs-max, bitwise ops) in shared memory, with backend-specific optimizations for CUDA (warp size 32) and HIP (wave size 64).
Shared-memory reduction support

- Added a SharedReduceWarp template in both src/tl_templates/cuda/reduce.h and src/tl_templates/hip/reduce.h for efficient warp/wave-level reductions, supporting sum, max, min, abs-sum, abs-max, bitwise and/or/xor, and accumulation logic. [1] [2]
- Backend-specific implementations for CUDA (warp size 32, __shfl_down_sync, __activemask) and HIP (wave size 64, __shfl_down). [1] [2]

Reduction lowering logic refactor
- Refactored ReduceOpNode::Lower in src/op/reduce.cc to support shared-memory reductions, including scope checks, buffer handling, dimension checks, and backend-specific constraints. It now generates shared-memory reduction calls for CUDA/HIP and falls back to thread-local reduction for fragments. [1] [2] [3]

Bitwise reduction support
- Added BitAndOp, BitOrOp, and BitXorOp functors to the HIP backend for completeness and parity with CUDA.
- Added testing/python/language/test_tilelang_language_reduce.py with extensive tests for shared-memory reductions (sum, max, min, abs-sum, abs-max), including correctness checks, clear/accumulate semantics, and backend validation.
These changes collectively enable efficient, robust, and well-tested shared-memory reductions across CUDA and HIP backends.
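The Reducer slot in the generated tl::SharedReduceWarp<Reducer, threads, abs, accumulate>::run string is a small binary functor chosen by MakeCodegenReducer(). A sketch of the family follows; the bitwise names match the PR description, while the arithmetic ones are assumed here and may be spelled differently in the actual headers:

```cpp
// Reduction functors: one binary combine step per T.reduce_* flavor.
template <typename T> struct SumOp    { __device__ T operator()(T a, T b) const { return a + b; } };
template <typename T> struct MaxOp    { __device__ T operator()(T a, T b) const { return a > b ? a : b; } };
template <typename T> struct MinOp    { __device__ T operator()(T a, T b) const { return a < b ? a : b; } };
template <typename T> struct BitAndOp { __device__ T operator()(T a, T b) const { return a & b; } };
template <typename T> struct BitOrOp  { __device__ T operator()(T a, T b) const { return a | b; } };
template <typename T> struct BitXorOp { __device__ T operator()(T a, T b) const { return a ^ b; } };
```

Abs-sum and abs-max would reuse the sum/max combine step, with the UseAbs template flag applying |x| to each element before combining.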