diff --git a/src/op/reduce.cc b/src/op/reduce.cc index 9d3dd8357..6bca83d8b 100644 --- a/src/op/reduce.cc +++ b/src/op/reduce.cc @@ -15,6 +15,7 @@ #include "../layout/utils.h" #include "../transform/loop_partition.h" +#include "tir/transforms/ir_utils.h" namespace tvm { namespace tl { @@ -124,10 +125,33 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { Array stmts; + bool require_init = this->clear; + // sum op must be cleared + if (this->type == ReduceType::kSum) { + require_init = true; + } else if (this->type == ReduceType::kAbsSum) { + require_init = true; + } + + Buffer clear_buffer = dst_buffer; + bool need_duplicate = false; + if (this->type == ReduceType::kSum && !this->clear) { + need_duplicate = true; + } else if (this->type == ReduceType::kAbsSum && !this->clear) { + need_duplicate = true; + } + + if (need_duplicate) { + // Create a new buffer with same shape and dtype as dst_buffer + clear_buffer = decl_buffer(dst_buffer->shape, dst_buffer->dtype, + dst_buffer->name + "_clear", + GetPtrStorageScope(dst_buffer->data)); + } + // make reduce-init stmt - if (this->clear) + if (require_init) stmts.push_back( - BufferStore(dst_buffer, this->MakeInitValue(), dst_indices)); + BufferStore(clear_buffer, this->MakeInitValue(), dst_indices)); // make thread-local reduce Array src_indice_compressed; @@ -141,8 +165,8 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { src_var_compressed.push_back(var); } Stmt reduce_local = BufferStore( - dst_buffer, - this->MakeReduce(BufferLoad(dst_buffer, dst_indices), + clear_buffer, + this->MakeReduce(BufferLoad(clear_buffer, dst_indices), BufferLoad(src_buffer, src_indice_compressed)), dst_indices); for (int i = src_layout->OutputDim() - 1; i >= 0; i--) { @@ -179,20 +203,37 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { << reducing_threads << ", " << (*scale) << ">::run"; } Array thread_reduce_args = { - StringImm(ss.str()), BufferLoad(dst_buffer, dst_indices)}; + StringImm(ss.str()), BufferLoad(clear_buffer, dst_indices)}; if (reducing_threads >= 32) { PrimExpr workspace = T.AddWorkspace( - *as_const_int(T.thread_bounds->extent), dst_buffer->dtype); + *as_const_int(T.thread_bounds->extent), clear_buffer->dtype); thread_reduce_args.push_back(workspace); } auto call = - Call(dst_buffer->dtype, builtin::call_extern(), thread_reduce_args); - stmts.push_back(BufferStore(dst_buffer, call, dst_indices)); + Call(clear_buffer->dtype, builtin::call_extern(), thread_reduce_args); + stmts.push_back(BufferStore(clear_buffer, call, dst_indices)); } } - Stmt reduce_interthread = - BufferStore(dst_buffer, BufferLoad(dst_buffer, dst_indices), dst_indices); + Stmt reduce_interthread = BufferStore( + clear_buffer, BufferLoad(clear_buffer, dst_indices), dst_indices); + // copy clear_buffer to dst_buffer + if (need_duplicate) { + // if is reduce sum, we should add a copy from clear_buffer to dst_buffer + if (this->type == ReduceType::kSum) { + stmts.push_back(BufferStore(dst_buffer, + Add(BufferLoad(dst_buffer, dst_indices), + BufferLoad(clear_buffer, dst_indices)), + dst_indices)); + } else if (this->type == ReduceType::kAbsSum) { + stmts.push_back(BufferStore(dst_buffer, + Add(BufferLoad(dst_buffer, dst_indices), + BufferLoad(clear_buffer, dst_indices)), + dst_indices)); + } else { + ICHECK(false) << "Unsupported reduce type: " << (int)this->type; + } + } // make the outer spatial loop Stmt body = stmts.size() > 1 ? SeqStmt(stmts) : stmts[0]; for (int i = dst_layout->InputDim() - 1; i >= 0; i--) { @@ -201,6 +242,10 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { } body = PartitionLoop(Downcast(body), T.thread_var, analyzer, dst_layout); + if (need_duplicate) { + body = Allocate(clear_buffer->data, clear_buffer->dtype, + clear_buffer->shape, const_true(), body); + } return body; } diff --git a/src/transform/pipeline_planning.cc b/src/transform/pipeline_planning.cc index 19b70a719..8e12c96db 100644 --- a/src/transform/pipeline_planning.cc +++ b/src/transform/pipeline_planning.cc @@ -389,7 +389,6 @@ class PipelinePlanner : public StmtExprMutator { // Handle trailing unassigned copy stages: // These are typically final copy operations needing post-main-stage // insertion - auto &head_pinfo = pipeline_stage_infos.at(0); int unassigned_order_elem = -1; @@ -422,7 +421,7 @@ class PipelinePlanner : public StmtExprMutator { int copy_order_min = pipeline_stage_infos.size(); int non_copy_order_max = 0; for (auto &pinfo : pipeline_stage_infos) { - if (pinfo.copy_stage) { + if (pinfo.copy_stage || pinfo.prepare_for_condition) { copy_stage_cnt++; copy_order_min = std::min(copy_order_min, pinfo.order); } else { @@ -437,7 +436,7 @@ class PipelinePlanner : public StmtExprMutator { for (auto &pinfo : pipeline_stage_infos) { // move copy to the beginning pinfo.order = (pinfo.order + copy_stage_at_end) % pipeline_stage_infos.size(); - if (!pinfo.copy_stage) + if (!pinfo.copy_stage && !pinfo.prepare_for_condition) pinfo.stage--; } } diff --git a/testing/python/language/test_tilelang_language_reduce_max.py b/testing/python/language/test_tilelang_language_reduce_max.py index 6a9156d5f..a004f816f 100644 --- a/testing/python/language/test_tilelang_language_reduce_max.py +++ b/testing/python/language/test_tilelang_language_reduce_max.py @@ -50,5 +50,46 @@ def test_reduce_max(): run_reduce_max(256, 256, "float16") +def reduce_max_test_clear(M, N, dtype="float16"): + import tilelang.language as T + + @T.prim_func + def main( + A: T.Tensor((M, N), dtype), + B: T.Tensor((M,), dtype), + ): + with T.Kernel(1, threads=32) as _: + A_local = T.alloc_fragment((M, N), dtype) + B_local = T.alloc_fragment((M,), dtype) + + T.copy(A, A_local) + T.fill(B_local, -T.infinity(dtype)) + T.reduce_max(A_local, B_local, dim=1, clear=False) + T.copy(B_local, B) + + return main + + +def run_reduce_max_clear(M, N, dtype="float16"): + program = reduce_max_test_clear(M, N, dtype) + jit_kernel = tl.compile(program, out_idx=-1) + print(jit_kernel.get_kernel_source()) + + def ref_program(A): + return A.max(dim=1).values + + import torch + dummp_A = torch.randn((M, N), dtype=getattr(torch, dtype)).cuda() + ref_out = ref_program(dummp_A) + tl_out = jit_kernel(dummp_A) + print(tl_out) + print(ref_out) + torch.testing.assert_close(tl_out, ref_out, atol=1e-2, rtol=1e-2) + + +def test_reduce_max_clear(): + run_reduce_max_clear(256, 256, "float16") + + if __name__ == "__main__": tilelang.testing.main() diff --git a/testing/python/language/test_tilelang_language_reduce_sum.py b/testing/python/language/test_tilelang_language_reduce_sum.py new file mode 100644 index 000000000..4958aab8d --- /dev/null +++ b/testing/python/language/test_tilelang_language_reduce_sum.py @@ -0,0 +1,105 @@ +# Copyright (c) Tile-AI Corporation. +# Licensed under the MIT License. + +from tilelang import tvm as tvm +import tilelang.testing +import tilelang as tl + +tilelang.testing.set_random_seed() + + +def reduce_sum_test(M, N, dtype="float16"): + import tilelang.language as T + + @T.prim_func + def main( + A: T.Tensor((M, N), dtype), + B: T.Tensor((M,), dtype), + ): + with T.Kernel(1) as _: + A_local = T.alloc_fragment((M, N), dtype) + B_local = T.alloc_fragment((M,), dtype) + + # Copy input to local + T.copy(A, A_local) + # Perform reduce_sum operation + T.reduce_sum(A_local, B_local, dim=1) + # Copy result back + T.copy(B_local, B) + + return main + + +def run_reduce_sum(M, N, dtype="float16"): + program = reduce_sum_test(M, N, dtype) + jit_kernel = tl.compile(program, out_idx=-1) + profiler = jit_kernel.get_profiler() + + def ref_program(A): + return A.sum(dim=1) + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + + +def test_reduce_sum(): + # Test different sizes + run_reduce_sum(256, 256) + run_reduce_sum(512, 128) + run_reduce_sum(128, 512) + + # Test different dtypes + run_reduce_sum(256, 256, "float32") + run_reduce_sum(256, 256, "float16") + + +def reduce_sum_test_clear(M, N, dtype="float16"): + import tilelang.language as T + + @T.prim_func + def main( + A: T.Tensor((M, N), dtype), + B: T.Tensor((M,), dtype), + ): + with T.Kernel(1, threads=32) as _: + A_local = T.alloc_fragment((M, N), dtype) + B_local = T.alloc_fragment((M,), dtype) + + T.copy(A, A_local) + T.fill(B_local, 1) + T.reduce_sum(A_local, B_local, dim=1, clear=False) + T.copy(B_local, B) + + return main + + +def run_reduce_sum_clear(M, N, dtype="float16"): + program = reduce_sum_test_clear(M, N, dtype) + jit_kernel = tl.compile( + program, + out_idx=-1, + pass_configs={ + "tl.disable_tma_lower": True, + "tl.disable_warp_specialized": True, + }) + print(jit_kernel.get_kernel_source()) + + def ref_program(A): + return A.sum(dim=1) + 1 + + import torch + dummp_A = torch.randn((M, N), dtype=getattr(torch, dtype)).cuda() + ref_out = ref_program(dummp_A) + tl_out = jit_kernel(dummp_A) + print(tl_out) + print(ref_out) + torch.testing.assert_close(tl_out, ref_out, atol=1e-2, rtol=1e-2) + + +def test_reduce_sum_clear(): + run_reduce_sum_clear(256, 256, "float32") + run_reduce_sum_clear(512, 128, "float32") + run_reduce_sum_clear(128, 512, "float32") + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/testing/python/transform/test_tilelang_transform_annotate_device_regions.py b/testing/python/transform/test_tilelang_transform_annotate_device_regions.py index 2732c4239..c8cbd9e23 100644 --- a/testing/python/transform/test_tilelang_transform_annotate_device_regions.py +++ b/testing/python/transform/test_tilelang_transform_annotate_device_regions.py @@ -13,12 +13,12 @@ class BaseCompare(tilelang.testing.CompareBeforeAfter): class TestAnnotateThreadExtent(BaseCompare): """Annotation inserted at the "thread_extent" attribute""" - def before(A: T.Buffer(16, "float32")): + def before(A: T.Tensor(16, "float32")): T.func_attr({"target": T.target("cuda", host="llvm")}) i = T.launch_thread("threadIdx.x", 16) A[i] = 0.0 - def expected(A: T.Buffer(16, "float32")): + def expected(A: T.Tensor(16, "float32")): T.func_attr({"target": T.target("cuda", host="llvm")}) T.attr(T.target("cuda"), "target", 0) i = T.launch_thread("threadIdx.x", 16) @@ -28,12 +28,12 @@ def expected(A: T.Buffer(16, "float32")): class TestAnnotateDeviceScope(BaseCompare): """Annotation inserted at the "device_scope" attribute""" - def before(A: T.Buffer(1, "float32")): + def before(A: T.Tensor(1, "float32")): T.func_attr({"target": T.target("cuda", host="llvm")}) T.attr(0, "device_scope", 0) A[0] = 0.0 - def expected(A: T.Buffer(1, "float32")): + def expected(A: T.Tensor(1, "float32")): T.func_attr({"target": T.target("cuda", host="llvm")}) T.attr(T.target("cuda"), "target", 0) T.attr(0, "device_scope", 0) diff --git a/tilelang/language/reduce.py b/tilelang/language/reduce.py index 83601786d..b634b77e6 100644 --- a/tilelang/language/reduce.py +++ b/tilelang/language/reduce.py @@ -68,18 +68,28 @@ def reduce_min(buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True return reduce(buffer, out, "min", dim, clear) -def reduce_sum(buffer: tir.Buffer, out: tir.Buffer, dim: int): +def reduce_sum(buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True): """Perform reduce sum on input buffer, store the result to output buffer. Args: buffer (tir.Buffer): The input buffer out (tir.Buffer): The output buffer dim (int): The dimension to perform reduce on + clear (bool, optional): If True, output buffer will be cleared before reduction. + If False, results will be accumulated on existing values. + Defaults to True. + Note: When clear=True, reduce_sum will not compute directly on the output buffer. This is because + during warp reduction, the same value would be accumulated multiple times (number of threads + in the warp). Therefore, the implementation with clear=True follows these steps: + 1. create a temp buffer with same shape and dtype as out + 2. copy out to temp buffer + 3. call reduce_sum with temp buffer and out + 4. Add temp buffer to out Returns: tir.Call: Handle to the reduction operation """ - return reduce(buffer, out, "sum", dim, True) + return reduce(buffer, out, "sum", dim, clear) def reduce_abssum(buffer: tir.Buffer, out: tir.Buffer, dim: int): @@ -96,7 +106,7 @@ def reduce_abssum(buffer: tir.Buffer, out: tir.Buffer, dim: int): return reduce(buffer, out, "abssum", dim, True) -def reduce_absmax(buffer: tir.Buffer, out: tir.Buffer, dim: int): +def reduce_absmax(buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True): """Perform reduce absolute max on input buffer, store the result to output buffer. Args: @@ -107,7 +117,7 @@ def reduce_absmax(buffer: tir.Buffer, out: tir.Buffer, dim: int): Returns: tir.Call: Handle to the reduction operation """ - return reduce(buffer, out, "absmax", dim, True) + return reduce(buffer, out, "absmax", dim, clear) @macro