From e663246877a616881326c2cea504d2dc592f6595 Mon Sep 17 00:00:00 2001 From: Lei Wang Date: Mon, 10 Nov 2025 19:19:10 +0800 Subject: [PATCH 1/3] [Refactor] Update ReduceOpNode to use absolute values in Max computation and remove unused shared memory reduction logic * Changed Max computation for AbsMax type to use absolute values of lhs and rhs. * Removed unused shared memory reduction logic and related checks for buffer dimensions and thread extents, simplifying the Lower method. * Added a fatal log for unsupported buffer scope reductions. --- src/op/reduce.cc | 66 +----------------------------------------------- 1 file changed, 1 insertion(+), 65 deletions(-) diff --git a/src/op/reduce.cc b/src/op/reduce.cc index 3e31aa2f1..b6ba14a91 100644 --- a/src/op/reduce.cc +++ b/src/op/reduce.cc @@ -104,7 +104,7 @@ PrimExpr ReduceOpNode::MakeReduce(const PrimExpr &lhs, } else if (type->isMin()) { return Min(lhs, rhs); } else if (type->isAbsMax()) { - return Max(Max(lhs, rhs), -Min(lhs, rhs)); + return Max(tvm::abs(lhs), tvm::abs(rhs)); } else if (type->isBitAnd()) { return lhs & rhs; } else if (type->isBitOr()) { @@ -360,70 +360,6 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { return body; } - auto is_shared_scope = [](const std::string &scope) { - return scope == "shared" || scope == "shared.dyn"; - }; - - if (is_shared_scope(src_scope) && is_shared_scope(dst_scope)) { - Buffer src_buffer = get_buffer(this->src); - Buffer dst_buffer = get_buffer(this->dst); - - size_t src_dim = src_buffer->shape.size(); - size_t dst_dim = dst_buffer->shape.size(); - bool is_1d_reduce = (src_dim == dst_dim && dst_dim == 1); - if (!is_1d_reduce) { - ICHECK_EQ(src_dim, dst_dim + 1) << "Reduce dimension mismatch."; - } else { - ICHECK_EQ(dst_dim, 1U) << "Expect scalar layout for 1D reduce."; - } - - auto thread_extent = as_const_int(T.thread_bounds->extent); - ICHECK(thread_extent) - << "Shared-memory reduce requires static thread extent."; - int threads = *thread_extent; - - if (TargetIsCuda(T.target)) { - ICHECK_EQ(threads % 32, 0) - << "Shared reduce expects blockDim.x to be a multiple of 32 on CUDA."; - } else if (TargetIsRocm(T.target)) { - ICHECK_EQ(threads % 64, 0) - << "Shared reduce expects blockDim.x to be a multiple of 64 on HIP."; - } - - bool use_abs = this->type->isAbsSum() || this->type->isAbsMax(); - bool need_accumulate = - (!this->clear) && (this->type->isSum() || this->type->isAbsSum() || - this->type->isBitAnd() || this->type->isBitOr() || - this->type->isBitXor()); - - PrimExpr reduce_extent = src_buffer->shape[this->dim]; - PrimExpr tail_extent = make_const(DataType::Int(32), 1); - for (size_t i = this->dim + 1; i < src_dim; ++i) { - tail_extent = analyzer->Simplify(tail_extent * src_buffer->shape[i]); - } - - PrimExpr total_dest = make_const(DataType::Int(32), 1); - for (size_t i = 0; i < dst_dim; ++i) { - total_dest = analyzer->Simplify(total_dest * dst_buffer->shape[i]); - } - - std::stringstream ss; - std::string reducer = this->MakeCodegenReducer(); - ss << "tl::SharedReduceWarp<" << reducer << ", " << threads << ", " - << (use_abs ? "true" : "false") << ", " - << (need_accumulate ? "true" : "false") << ">::run"; - - Array call_args = {StringImm(ss.str()), - src_buffer.access_ptr(1), - dst_buffer.access_ptr(3), - cast(DataType::Int(32), total_dest), - cast(DataType::Int(32), reduce_extent), - cast(DataType::Int(32), tail_extent), - this->MakeInitValue()}; - - return Evaluate(Call(dst_buffer->dtype, builtin::call_extern(), call_args)); - } - LOG(FATAL) << "Reduce for buffers in scope (" << src_scope << ", " << dst_scope << ") is not implemented."; return Stmt(); From 570c481a8bea307fb4d439f9e2cf26c692f67d73 Mon Sep 17 00:00:00 2001 From: Lei Wang Date: Mon, 10 Nov 2025 20:05:17 +0800 Subject: [PATCH 2/3] reduce fix --- tilelang/language/reduce.py | 79 +++++++++++++++++++++++++++++++------ 1 file changed, 67 insertions(+), 12 deletions(-) diff --git a/tilelang/language/reduce.py b/tilelang/language/reduce.py index 55ac2bb0d..3ebfe7558 100644 --- a/tilelang/language/reduce.py +++ b/tilelang/language/reduce.py @@ -2,7 +2,9 @@ from __future__ import annotations from tvm import tir -from tilelang.language import copy, macro, alloc_shared +from tilelang.language import copy, macro, alloc_shared, alloc_fragment +from tilelang.utils.language import is_shared, is_fragment +from tvm.script.ir_builder import IRBuilder def _legalize_dim(buffer: tir.Buffer, dim: int): @@ -34,17 +36,70 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea raise ValueError( f"Invalid reduce output shape, buffer shape is {buffer.shape}, dim is {dim}, " f"output shape is {out.shape}, expected shapes are {expected_shapes_str}") - buffer = buffer.access_ptr("r") - out = out.access_ptr("w") - return tir.call_intrin( - "handle", - tir.op.Op.get("tl.reduce"), - buffer, - out, - reduce_type, - dim, - clear, - ) + + @macro + def reduce_macro(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clear: bool): + if is_shared(buffer) and is_shared(out): + red_frag_in = alloc_fragment(buffer.shape, buffer.dtype) + red_frag_out = alloc_fragment(out.shape, out.dtype) + + # rename buffers + IRBuilder.name(buffer.name + "_frag", red_frag_in) + IRBuilder.name(out.name + "_frag", red_frag_out) + + copy(buffer, red_frag_in) + tir.call_intrin( + "handle", + tir.op.Op.get("tl.reduce"), + red_frag_in.access_ptr("r"), + red_frag_out.access_ptr("w"), + reduce_type, + dim, + clear, + ) + copy(red_frag_out, out) + elif is_shared(buffer) and is_fragment(out): + red_frag_in = alloc_fragment(buffer.shape, buffer.dtype) + IRBuilder.name(buffer.name + "_frag", red_frag_in) + + copy(buffer, red_frag_in) + tir.call_intrin( + "handle", + tir.op.Op.get("tl.reduce"), + red_frag_in.access_ptr("r"), + out.access_ptr("w"), + reduce_type, + dim, + clear, + ) + elif is_fragment(buffer) and is_shared(out): + red_frag_out = alloc_fragment(out.shape, out.dtype) + IRBuilder.name(out.name + "_frag", red_frag_out) + + tir.call_intrin( + "handle", + tir.op.Op.get("tl.reduce"), + buffer.access_ptr("r"), + red_frag_out.access_ptr("w"), + reduce_type, + dim, + clear, + ) + copy(red_frag_out, out) + elif is_fragment(buffer) and is_fragment(out): + tir.call_intrin( + "handle", + tir.op.Op.get("tl.reduce"), + buffer.access_ptr("r"), + out.access_ptr("w"), + reduce_type, + dim, + clear, + ) + else: + raise ValueError(f"Invalid buffer scopes: {buffer.scope()} and {out.scope()}") + + return reduce_macro(buffer, out, reduce_type, dim, clear) def reduce_max(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True): From 814eed7289dbd1633c25f6c40302809478278299 Mon Sep 17 00:00:00 2001 From: Lei Wang Date: Mon, 10 Nov 2025 22:45:15 +0800 Subject: [PATCH 3/3] [Fix] Update type check for eval value in Builder class * Changed the type check for eval values to raise a TypeError for unsupported types, specifically excluding instances of tvm.tir.Buffer. This improves error handling and clarity in the Builder class. --- tilelang/language/v2/builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tilelang/language/v2/builder.py b/tilelang/language/v2/builder.py index 53955c4c1..d3835a8a8 100644 --- a/tilelang/language/v2/builder.py +++ b/tilelang/language/v2/builder.py @@ -245,7 +245,7 @@ def eval(self, val: Any): pass elif isinstance(val, tvm.tir.stmt.BufferStore): tir.buffer_store(val.buffer, val.value, val.indices, val.predicate) - else: + elif not isinstance(val, tvm.tir.Buffer): raise TypeError(f"Unsupported eval value: {val} of type {type(val)}") def ctx_for(self, it):