Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 1 addition & 65 deletions src/op/reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ PrimExpr ReduceOpNode::MakeReduce(const PrimExpr &lhs,
} else if (type->isMin()) {
return Min(lhs, rhs);
} else if (type->isAbsMax()) {
return Max(Max(lhs, rhs), -Min(lhs, rhs));
return Max(tvm::abs(lhs), tvm::abs(rhs));
} else if (type->isBitAnd()) {
return lhs & rhs;
} else if (type->isBitOr()) {
Expand Down Expand Up @@ -360,70 +360,6 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return body;
}

auto is_shared_scope = [](const std::string &scope) {
return scope == "shared" || scope == "shared.dyn";
};

if (is_shared_scope(src_scope) && is_shared_scope(dst_scope)) {
Buffer src_buffer = get_buffer(this->src);
Buffer dst_buffer = get_buffer(this->dst);

size_t src_dim = src_buffer->shape.size();
size_t dst_dim = dst_buffer->shape.size();
bool is_1d_reduce = (src_dim == dst_dim && dst_dim == 1);
if (!is_1d_reduce) {
ICHECK_EQ(src_dim, dst_dim + 1) << "Reduce dimension mismatch.";
} else {
ICHECK_EQ(dst_dim, 1U) << "Expect scalar layout for 1D reduce.";
}

auto thread_extent = as_const_int(T.thread_bounds->extent);
ICHECK(thread_extent)
<< "Shared-memory reduce requires static thread extent.";
int threads = *thread_extent;

if (TargetIsCuda(T.target)) {
ICHECK_EQ(threads % 32, 0)
<< "Shared reduce expects blockDim.x to be a multiple of 32 on CUDA.";
} else if (TargetIsRocm(T.target)) {
ICHECK_EQ(threads % 64, 0)
<< "Shared reduce expects blockDim.x to be a multiple of 64 on HIP.";
}

bool use_abs = this->type->isAbsSum() || this->type->isAbsMax();
bool need_accumulate =
(!this->clear) && (this->type->isSum() || this->type->isAbsSum() ||
this->type->isBitAnd() || this->type->isBitOr() ||
this->type->isBitXor());

PrimExpr reduce_extent = src_buffer->shape[this->dim];
PrimExpr tail_extent = make_const(DataType::Int(32), 1);
for (size_t i = this->dim + 1; i < src_dim; ++i) {
tail_extent = analyzer->Simplify(tail_extent * src_buffer->shape[i]);
}

PrimExpr total_dest = make_const(DataType::Int(32), 1);
for (size_t i = 0; i < dst_dim; ++i) {
total_dest = analyzer->Simplify(total_dest * dst_buffer->shape[i]);
}

std::stringstream ss;
std::string reducer = this->MakeCodegenReducer();
ss << "tl::SharedReduceWarp<" << reducer << ", " << threads << ", "
<< (use_abs ? "true" : "false") << ", "
<< (need_accumulate ? "true" : "false") << ">::run";

Array<PrimExpr> call_args = {StringImm(ss.str()),
src_buffer.access_ptr(1),
dst_buffer.access_ptr(3),
cast(DataType::Int(32), total_dest),
cast(DataType::Int(32), reduce_extent),
cast(DataType::Int(32), tail_extent),
this->MakeInitValue()};

return Evaluate(Call(dst_buffer->dtype, builtin::call_extern(), call_args));
}

LOG(FATAL) << "Reduce for buffers in scope (" << src_scope << ", "
<< dst_scope << ") is not implemented.";
return Stmt();
Expand Down
79 changes: 67 additions & 12 deletions tilelang/language/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
from __future__ import annotations

from tvm import tir
from tilelang.language import copy, macro, alloc_shared
from tilelang.language import copy, macro, alloc_shared, alloc_fragment
from tilelang.utils.language import is_shared, is_fragment
from tvm.script.ir_builder import IRBuilder


def _legalize_dim(buffer: tir.Buffer, dim: int):
Expand Down Expand Up @@ -34,17 +36,70 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea
raise ValueError(
f"Invalid reduce output shape, buffer shape is {buffer.shape}, dim is {dim}, "
f"output shape is {out.shape}, expected shapes are {expected_shapes_str}")
buffer = buffer.access_ptr("r")
out = out.access_ptr("w")
return tir.call_intrin(
"handle",
tir.op.Op.get("tl.reduce"),
buffer,
out,
reduce_type,
dim,
clear,
)

@macro
def reduce_macro(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clear: bool):
if is_shared(buffer) and is_shared(out):
red_frag_in = alloc_fragment(buffer.shape, buffer.dtype)
red_frag_out = alloc_fragment(out.shape, out.dtype)

# rename buffers
IRBuilder.name(buffer.name + "_frag", red_frag_in)
IRBuilder.name(out.name + "_frag", red_frag_out)

copy(buffer, red_frag_in)
tir.call_intrin(
"handle",
tir.op.Op.get("tl.reduce"),
red_frag_in.access_ptr("r"),
red_frag_out.access_ptr("w"),
reduce_type,
dim,
clear,
)
copy(red_frag_out, out)
elif is_shared(buffer) and is_fragment(out):
red_frag_in = alloc_fragment(buffer.shape, buffer.dtype)
IRBuilder.name(buffer.name + "_frag", red_frag_in)

copy(buffer, red_frag_in)
tir.call_intrin(
"handle",
tir.op.Op.get("tl.reduce"),
red_frag_in.access_ptr("r"),
out.access_ptr("w"),
reduce_type,
dim,
clear,
)
elif is_fragment(buffer) and is_shared(out):
red_frag_out = alloc_fragment(out.shape, out.dtype)
IRBuilder.name(out.name + "_frag", red_frag_out)

tir.call_intrin(
"handle",
tir.op.Op.get("tl.reduce"),
buffer.access_ptr("r"),
red_frag_out.access_ptr("w"),
reduce_type,
dim,
clear,
)
copy(red_frag_out, out)
Comment on lines +50 to +88
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Preserve destination contents when clear=False.

When out lives in shared memory, we stage it through red_frag_out, but we never preload its existing values. For reduce_* calls with clear=False, that staging buffer must start from the prior accumulator; otherwise we zero out (or leave undefined) red_frag_out, call tl.reduce, and lose the intended accumulation semantics. This regresses any caller relying on clear=False.

Please seed red_frag_out from out before the intrinsic whenever clear is false, in both shared→shared and fragment→shared branches, e.g.:

         red_frag_out = alloc_fragment(out.shape, out.dtype)
         IRBuilder.name(out.name + "_frag", red_frag_out)

         copy(buffer, red_frag_in)
+        if not clear:
+            copy(out, red_frag_out)
         tir.call_intrin(
@@
         red_frag_out = alloc_fragment(out.shape, out.dtype)
         IRBuilder.name(out.name + "_frag", red_frag_out)

+        if not clear:
+            copy(out, red_frag_out)
         tir.call_intrin(
🤖 Prompt for AI Agents
In tilelang/language/reduce.py around lines 50 to 88, when staging a
shared-memory destination through red_frag_out we never preload its existing
values, so reduce calls with clear=False lose prior accumulator state; fix by,
after allocating and naming red_frag_out (both in the shared->shared branch and
the fragment->shared branch), conditionally copy(out, red_frag_out) when clear
is False before invoking tir.call_intrin so the fragment starts seeded from the
current out contents.

elif is_fragment(buffer) and is_fragment(out):
tir.call_intrin(
"handle",
tir.op.Op.get("tl.reduce"),
buffer.access_ptr("r"),
out.access_ptr("w"),
reduce_type,
dim,
clear,
)
else:
raise ValueError(f"Invalid buffer scopes: {buffer.scope()} and {out.scope()}")

return reduce_macro(buffer, out, reduce_type, dim, clear)


def reduce_max(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True):
Expand Down
2 changes: 1 addition & 1 deletion tilelang/language/v2/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def eval(self, val: Any):
pass
elif isinstance(val, tvm.tir.stmt.BufferStore):
tir.buffer_store(val.buffer, val.value, val.indices, val.predicate)
else:
elif not isinstance(val, tvm.tir.Buffer):
raise TypeError(f"Unsupported eval value: {val} of type {type(val)}")

def ctx_for(self, it):
Expand Down
Loading