diff --git a/src/op/reduce.cc b/src/op/reduce.cc
index 588b5d60c..9d3dd8357 100644
--- a/src/op/reduce.cc
+++ b/src/op/reduce.cc
@@ -240,5 +240,72 @@ TIR_REGISTER_TL_OP(ReduceOp, reduce)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));
 
+/*!
+ * \brief Build a cumsum operator from the arguments of a `tl.cumsum` call.
+ *
+ * CumSum arguments:
+ *   src: input buffer
+ *   dst: output buffer
+ *   dim: dimension to cumsum
+ *   reverse: whether to cumsum in reverse order
+ */
+CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) {
+  CHECK_EQ(args.size(), 4);
+  src = vmap[GetVarFromAccessPtr(args[0])];
+  dst = vmap[GetVarFromAccessPtr(args[1])];
+  dim = args[2].as<IntImm>().value()->value;
+  reverse = args[3].as<Bool>().value();
+  // A negative dim would otherwise be emitted verbatim as the Axis template
+  // argument of the device kernel.
+  CHECK_GE(dim, 0);
+  CHECK_LT(dim, static_cast<int>(src->shape.size()));
+}
+
+/*!
+ * \brief Lower cumsum on shared-memory buffers to a `tl::CumSum2D` call.
+ *
+ * Fragment-to-fragment cumsum is rejected with a fatal error, and any other
+ * scope combination fails an internal check.
+ */
+Stmt CumSumOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
+  if (this->src.scope() == "local.fragment" &&
+      this->dst.scope() == "local.fragment") {
+    LOG(FATAL) << "CumSum for fragment not implemented, please raise an issue "
+                  "if you need this feature.";
+  } else if (this->src.scope() == "shared.dyn" ||
+             this->src.scope() == "shared") {
+    ICHECK(this->dst.scope() == "shared.dyn" || this->dst.scope() == "shared");
+    // tl::CumSum2D::run takes exactly (src, dst, H, W).
+    ICHECK_EQ(src->shape.size(), 2U)
+        << "CumSum only supports 2D buffers, but got " << src->shape.size()
+        << " dimensions";
+    std::stringstream ss;
+    // thread_bounds is a (min, extent) pair, so extent alone is the thread
+    // count; subtracting min would under-count whenever min is non-zero.
+    auto threads = T.thread_bounds->extent;
+    ss << "tl::CumSum2D<" << threads << ", " << dim << ", "
+       << (reverse ? "true" : "false") << ">::run";
+    Array<PrimExpr> args = {StringImm(ss.str()), src.access_ptr(1),
+                            dst.access_ptr(3)};
+    for (const PrimExpr &extent : src->shape) {
+      args.push_back(extent);
+    }
+    return Evaluate(Call(dst->dtype, builtin::call_extern(), args));
+  } else {
+    ICHECK(false) << "Cannot lower cumsum for " << this->src.scope() << " and "
+                  << this->dst.scope();
+  }
+
+  return Stmt();
+}
+
+LayoutMap CumSumOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
+  return {};
+}
+
+TIR_REGISTER_TL_OP(CumSumOp, cumsum)
+    .set_num_inputs(4)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kOpaque));
 } // namespace tl
 } // namespace tvm
\ No newline at end of file
diff --git a/src/op/reduce.h b/src/op/reduce.h
index 1679f0504..9f610ff7e 100644
--- a/src/op/reduce.h
+++ b/src/op/reduce.h
@@ -41,6 +41,22 @@ class ReduceOp : public Operator {
   std::string MakeCodegenReducer() const;
 };
 
+/*!
+ * \brief Cumulative sum of a buffer along one dimension.
+ */
+class CumSumOp : public Operator {
+public:
+  CumSumOp(Array<PrimExpr> args, BufferMap vmap);
+  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
+  LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
+  static const Op &Get();
+
+private:
+  tir::Buffer src, dst; // input and output buffers
+  int dim{0};           // dimension to cumsum
+  bool reverse{false};  // whether to cumsum in reverse order
+};
+
 } // namespace tl
 } // namespace tvm
 
diff --git a/src/tl_templates/cuda/reduce.h b/src/tl_templates/cuda/reduce.h
index f86191851..4ddf002e0 100644
--- a/src/tl_templates/cuda/reduce.h
+++ b/src/tl_templates/cuda/reduce.h
@@ -67,4 +67,74 @@ struct AllReduce {
   }
 };
 
+/*!
+ * \brief Block-wide cumulative sum over a row-major H x W buffer.
+ *
+ * Axis == 1 scans along each row and Axis == 0 along each column. Each warp
+ * scans one line at a time in segments of SEG elements using shuffles, and
+ * carries the running total from one segment into the next. With `reverse`
+ * the accumulation runs from the end of a line towards its start.
+ *
+ * `src` and `dst` may refer to the same buffer (in-place cumsum), so the
+ * pointers are deliberately not marked __restrict__.
+ */
+template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
+  static_assert(threads == 1024 or threads == 512 or threads == 256 or
+                threads == 128 or threads == 64 or threads == 32);
+  static_assert(Axis == 0 or Axis == 1, "CumSum2D only supports Axis 0 or 1");
+  template <typename T, int SEG = 32>
+  static TL_DEVICE T run(const T *src, T *dst, int H, int W) {
+    static_assert(SEG == 32, "CumSum2D assigns one full warp to each line");
+
+    constexpr int LINES_PER_PASS = threads / SEG;
+    constexpr unsigned MASK = 0xffffffff;
+    // Number of independent lines and the number of elements in each line.
+    const int num_lines = Axis == 1 ? H : W;
+    const int line_len = Axis == 1 ? W : H;
+    const int num_segs = (line_len + SEG - 1) / SEG;
+    const int lane = threadIdx.x % SEG;
+    const int warp = threadIdx.x / SEG;
+
+    for (int line = warp; line < num_lines; line += LINES_PER_PASS) {
+      T carry = static_cast<T>(0);
+
+      for (int s = 0; s < num_segs; ++s) {
+        // Reverse mode walks the segments from the last one to the first.
+        const int seg = reverse ? num_segs - 1 - s : s;
+        const int pos = seg * SEG + lane;
+        const bool in_range = pos < line_len;
+        const int idx = Axis == 1 ? line * W + pos : pos * W + line;
+
+        // Out-of-range lanes contribute zero but still take part in shuffles.
+        T val = in_range ? src[idx] : static_cast<T>(0);
+
+#pragma unroll
+        for (int off = 1; off < SEG; off <<= 1) {
+          if (reverse) {
+            T n = static_cast<T>(__shfl_down_sync(MASK, val, off));
+            if (lane < SEG - off)
+              val += n;
+          } else {
+            T n = static_cast<T>(__shfl_up_sync(MASK, val, off));
+            if (lane >= off)
+              val += n;
+          }
+        }
+
+        val += carry;
+
+        if (in_range)
+          dst[idx] = val;
+
+        // The lane holding the whole segment total (lane 0 in reverse mode,
+        // the last lane otherwise) broadcasts it as the next carry.
+        carry = static_cast<T>(__shfl_sync(MASK, val, reverse ? 0 : SEG - 1));
+      }
+    }
+
+    // The result lives in `dst`; the return value carries no information.
+    return static_cast<T>(0);
+  }
+};
+
 } // namespace tl
diff --git a/testing/python/language/test_tilelang_language_cumsum.py b/testing/python/language/test_tilelang_language_cumsum.py
new file mode 100644
index 000000000..603ea2c36
--- /dev/null
+++ b/testing/python/language/test_tilelang_language_cumsum.py
@@ -0,0 +1,96 @@
+# Copyright (c) Tile-AI Corporation.
+# Licensed under the MIT License.
+
+from tilelang import tvm as tvm
+import tilelang.testing
+import tilelang as tl
+import torch
+
+
+def cumsum_smem_test(M, N, block_M, block_N, dim=0, reverse=False, dtype="float16"):
+    """Build a kernel that cumsums each (block_M, block_N) tile in shared memory."""
+    import tilelang.language as T
+
+    @T.prim_func
+    def cumsum(
+            A: T.Tensor((M, N), dtype),
+            B: T.Tensor((M, N), dtype),
+    ):
+        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
+            A_shared = T.alloc_shared((block_M, block_N), dtype)
+
+            T.copy(A[by * block_M, bx * block_N], A_shared)
+            T.cumsum(src=A_shared, dim=dim, reverse=reverse)
+            T.copy(A_shared, B[by * block_M, bx * block_N])
+
+    return cumsum
+
+
+def cumsum_fragment_test(M, N, block_M, block_N, dim=0, reverse=False, dtype="float16"):
+    """Build a kernel that cumsums each (block_M, block_N) tile held in a fragment."""
+    import tilelang.language as T
+
+    @T.prim_func
+    def cumsum(
+            A: T.Tensor((M, N), dtype),
+            B: T.Tensor((M, N), dtype),
+    ):
+        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
+            A_shared = T.alloc_shared((block_M, block_N), dtype)
+            A_fragment = T.alloc_fragment((block_M, block_N), dtype)
+
+            T.copy(A[by * block_M, bx * block_N], A_shared)
+            T.copy(A_shared, A_fragment)
+            T.cumsum(src=A_fragment, dim=dim, reverse=reverse)
+            T.copy(A_fragment, B[by * block_M, bx * block_N])
+
+    return cumsum
+
+
+def run_cumsum(M, N, block_M, block_N, dim=0, reverse=False, dtype="float16", scope="smem"):
+    """Compile the cumsum kernel for `scope` and check it against a tiled torch reference."""
+    builders = {"smem": cumsum_smem_test, "fragment": cumsum_fragment_test}
+    if scope not in builders:
+        raise ValueError(f"Unknown scope: {scope!r}, expected one of {sorted(builders)}")
+    program = builders[scope](M, N, block_M, block_N, dim, reverse, dtype)
+    jit_kernel = tl.compile(program, out_idx=-1)
+    profiler = jit_kernel.get_profiler(tensor_supply_type=tl.TensorSupplyType.One)
+
+    def ref_program(A):
+        # Tile-local scan; a reverse cumsum is flip -> cumsum -> flip along `dim`.
+        ref_b = torch.empty_like(A)
+        for i in range(M // block_M):
+            for j in range(N // block_N):
+                rows = slice(i * block_M, (i + 1) * block_M)
+                cols = slice(j * block_N, (j + 1) * block_N)
+                tile = A[rows, cols].flip(dims=[dim]) if reverse else A[rows, cols]
+                tile = tile.cumsum(dim=dim)
+                ref_b[rows, cols] = tile.flip(dims=[dim]) if reverse else tile
+        return ref_b
+
+    profiler.assert_allclose(ref_program)
+
+
+def test_cumsum_smem():
+    # Test both scan dimensions and the reverse direction
+    run_cumsum(1024, 1024, 128, 128)
+    run_cumsum(1024, 1024, 128, 128, dim=1)
+    run_cumsum(1024, 1024, 128, 128, dim=1, reverse=True)
+
+    # Test different dtypes
+    run_cumsum(256, 256, 128, 128, dtype="float32")
+    run_cumsum(256, 256, 128, 128, dtype="float16")
+
+
+def test_cumsum_fragment():
+    run_cumsum(1024, 1024, 128, 128, scope="fragment")
+    run_cumsum(1024, 1024, 128, 128, dim=1, scope="fragment")
+    run_cumsum(1024, 1024, 128, 128, dim=1, reverse=True, scope="fragment")
+
+    # Test different dtypes
+    run_cumsum(256, 256, 128, 128, dtype="float32", scope="fragment")
+    run_cumsum(256, 256, 128, 128, dtype="float16", scope="fragment")
+
+
+if __name__ == "__main__":
+    tilelang.testing.main()
diff --git a/testing/python/language/test_tilelang_language_reduce_sum.py b/testing/python/language/test_tilelang_language_reduce_sum.py
new file mode 100644
index 000000000..d88fccac3
--- /dev/null
+++ b/testing/python/language/test_tilelang_language_reduce_sum.py
@@ -0,0 +1,57 @@
+# Copyright (c) Tile-AI Corporation.
+# Licensed under the MIT License.
+
+from tilelang import tvm as tvm
+import tilelang.testing
+import tilelang as tl
+
+tilelang.disable_cache()
+
+
+def reduce_sum_test(M, N, dtype="float16"):
+    """Build a kernel that sums each row of an (M, N) tensor into an (M,) tensor."""
+    import tilelang.language as T
+
+    @T.prim_func
+    def main(
+            A: T.Tensor((M, N), dtype),
+            B: T.Tensor((M,), dtype),
+    ):
+        with T.Kernel(1) as _:
+            A_local = T.alloc_fragment((M, N), dtype)
+            B_local = T.alloc_fragment((M,), dtype)
+
+            T.copy(A, A_local)
+            # Reduce over the column axis so the result has shape (M,), matching B
+            T.reduce_sum(A_local, B_local, dim=1)
+            T.copy(B_local, B)
+
+    return main
+
+
+def run_reduce_sum(M, N, dtype="float16"):
+    """Compile the reduce_sum kernel and check it against torch's sum over dim 1."""
+    program = reduce_sum_test(M, N, dtype)
+    jit_kernel = tl.compile(program, out_idx=-1)
+    print(jit_kernel.get_kernel_source())
+    profiler = jit_kernel.get_profiler()
+
+    def ref_program(A):
+        return A.sum(dim=1)
+
+    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
+
+
+def test_reduce_sum():
+    # Test different sizes
+    run_reduce_sum(256, 256)
+    run_reduce_sum(512, 128)
+    run_reduce_sum(128, 512)
+
+    # Test different dtypes
+    run_reduce_sum(256, 256, "float32")
+    run_reduce_sum(256, 256, "float16")
+
+
+if __name__ == "__main__":
+    tilelang.testing.main()
diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py
index 7108059e9..e0119b5c0 100644
--- a/tilelang/language/__init__.py
+++ b/tilelang/language/__init__.py
@@ -50,6 +50,7 @@
     reduce_sum,  # noqa: F401
     reduce_abssum,  # noqa: F401
     reduce_absmax,  # noqa: F401
+    cumsum,  # noqa: F401
 )
 from .print import print  # noqa: F401
 from .customize import (
diff --git a/tilelang/language/reduce.py b/tilelang/language/reduce.py
index 9bbe3c7e8..63bfaebf5 100644
--- a/tilelang/language/reduce.py
+++ b/tilelang/language/reduce.py
@@ -3,6 +3,8 @@
 """The language interface for tl programs."""
 
 from tvm import tir
+from typing import Optional
+from tilelang.language import copy, macro, alloc_shared
 
 
 def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clear: bool):
@@ -106,3 +108,57 @@ def reduce_absmax(buffer: tir.Buffer, out: tir.Buffer, dim: int):
         tir.Call: Handle to the reduction operation
     """
     return reduce(buffer, out, "absmax", dim, True)
+
+
+# Fragment buffers are staged through a temporary "shared.dyn" buffer: copy in,
+# run `tl.cumsum` in place on the shared buffer, then copy the result to `dst`.
+@macro
+def cumsum_fragment(src: tir.Buffer, dst: tir.Buffer, dim: int, reverse: bool) -> tir.PrimExpr:
+    cumsum_smem = alloc_shared(src.shape, src.dtype, "shared.dyn")
+    copy(src, cumsum_smem)
+    tir.call_intrin(
+        "handle",
+        tir.op.Op.get("tl.cumsum"),
+        cumsum_smem.access_ptr("r"),
+        cumsum_smem.access_ptr("w"),
+        dim,
+        reverse,
+    )
+    copy(cumsum_smem, dst)
+
+
+def cumsum(src: tir.Buffer, dst: Optional[tir.Buffer] = None, dim: int = 0, reverse: bool = False):
+    """Compute the cumulative sum of `src` along dimension `dim`.
+
+    Args:
+        src: Input buffer.
+        dst: Output buffer. When omitted, the result is written back into `src`.
+        dim: Dimension to accumulate along; negative values count from the last one.
+        reverse: Whether to accumulate in reverse order along `dim`.
+
+    Returns:
+        The expansion of `cumsum_fragment` for "local.fragment" sources, otherwise
+        the `tl.cumsum` intrinsic call.
+
+    Raises:
+        ValueError: If `dim` is out of range for `src`.
+    """
+    ndim = len(src.shape)
+    dim = int(dim)
+    if not -ndim <= dim < ndim:
+        raise ValueError(f"dim must be in [{-ndim}, {ndim - 1}], but got {dim}")
+    if dim < 0:
+        # `dim` is forwarded as-is to the intrinsic, so pass a non-negative index.
+        dim += ndim
+    if dst is None:
+        dst = src
+    if src.scope() == "local.fragment":
+        return cumsum_fragment(src, dst, dim, reverse)
+    return tir.call_intrin(
+        "handle",
+        tir.op.Op.get("tl.cumsum"),
+        src.access_ptr("r"),
+        dst.access_ptr("w"),
+        dim,
+        reverse,
+    )
diff --git a/tilelang/utils/tensor.py b/tilelang/utils/tensor.py
index 4f22f7b9a..1b2238986 100644
--- a/tilelang/utils/tensor.py
+++ b/tilelang/utils/tensor.py
@@ -295,6 +295,8 @@ def torch_assert_close(
             f"({max_mismatched_ratio * 100:.2f}% allowed, but get {num_mismatched / total_elements * 100:.2f}%)."
             f"{mismatch_info}"
             f"\nGreatest absolute difference: {diff.max().item()}, "
-            f"Greatest relative difference: {(diff / (torch.abs(tensor_b) + 1e-12)).max().item()}.")
+            f"Greatest relative difference: {(diff / (torch.abs(tensor_b) + 1e-12)).max().item()}"
+            f"\nLHS: {tensor_a}"
+            f"\nRHS: {tensor_b}")
     else:
         return True