50 changes: 50 additions & 0 deletions src/op/reduce.cc
@@ -240,5 +240,55 @@ TIR_REGISTER_TL_OP(ReduceOp, reduce)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) {
/*
CumSum arguments:
src: input buffer
dst: output buffer
dim: dimension to cumsum
reverse: whether to cumsum in reverse order
*/
CHECK_EQ(args.size(), 4);
src = vmap[GetVarFromAccessPtr(args[0])];
dst = vmap[GetVarFromAccessPtr(args[1])];
dim = args[2].as<IntImm>().value()->value;
reverse = args[3].as<Bool>().value();
CHECK_LT(dim, static_cast<int>(src->shape.size()));
}

Stmt CumSumOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
if (this->src.scope() == "local.fragment" &&
this->dst.scope() == "local.fragment") {
LOG(FATAL) << "CumSum for fragment not implemented, please raise an issue "
"if you need this feature.";
} else if (this->src.scope() == "shared.dyn" ||
this->src.scope() == "shared") {
ICHECK(this->dst.scope() == "shared.dyn" || this->dst.scope() == "shared");
std::stringstream ss;
auto threads = T.thread_bounds->extent - T.thread_bounds->min;
ss << "tl::CumSum2D<" << threads << ", " << dim << ", "
<< (reverse ? "true" : "false") << ">::run";
Array<PrimExpr> args = {StringImm(ss.str()), src.access_ptr(1),
dst.access_ptr(3)};
for (int i = 0; i < src->shape.size(); i++) {
args.push_back(src->shape[i]);
}
return Evaluate(Call(dst->dtype, builtin::call_extern(), args));
} else {
ICHECK(false) << "Cannot lower cumsum for " << this->src.scope() << " and "
<< this->dst.scope();
}

return Stmt();
}

LayoutMap CumSumOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
return {};
}

TIR_REGISTER_TL_OP(CumSumOp, cumsum)
.set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
} // namespace tl
} // namespace tvm
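Note: the lowering above emits an extern call to `tl::CumSum2D<threads, dim, reverse>::run`, passing read/write pointers and appending the buffer shape as trailing arguments. For reference, a minimal sketch of the intended semantics (illustration only, not the generated code; NumPy and the helper name `cumsum_ref` are assumptions for the example):

```python
import numpy as np

def cumsum_ref(src: np.ndarray, dim: int = 0, reverse: bool = False) -> np.ndarray:
    """Reference semantics for tl.cumsum: inclusive cumulative sum along `dim`,
    accumulating from the end of the axis when `reverse` is set."""
    if reverse:
        return np.flip(np.cumsum(np.flip(src, axis=dim), axis=dim), axis=dim)
    return np.cumsum(src, axis=dim)
```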
13 changes: 13 additions & 0 deletions src/op/reduce.h
@@ -41,6 +41,19 @@ class ReduceOp : public Operator {
std::string MakeCodegenReducer() const;
};

class CumSumOp : public Operator {
public:
CumSumOp(Array<PrimExpr> args, BufferMap vmap);
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
static const Op &Get();

private:
tir::Buffer src, dst;
int dim;
bool reverse;
};

} // namespace tl
} // namespace tvm

79 changes: 79 additions & 0 deletions src/tl_templates/cuda/reduce.h
@@ -67,4 +67,83 @@ struct AllReduce {
}
};

template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
static_assert(threads == 1024 or threads == 512 or threads == 256 or
threads == 128 or threads == 64 or threads == 32);
template <typename T, int SEG = 32>
  static TL_DEVICE void run(const T *__restrict__ src, T *__restrict__ dst,
                            int H, int W) {

constexpr int TILE_H = threads / SEG;
constexpr unsigned MASK = 0xffffffff;
const int num_blocks = (H + TILE_H - 1) / TILE_H;
const int tid = threadIdx.x;
const int lane = tid % 32;
const int row = tid / 32;

for (int b = 0; b < num_blocks; ++b) {
const int gRow = b * TILE_H + row;
if (gRow >= H)
return;

T carry = (T)0;

if (reverse) {
// Start from the last segment for reverse mode
for (int seg = (W + SEG - 1) / SEG - 1; seg >= 0; --seg) {
const int col = seg * SEG + lane;

const int real_row = Axis == 1 ? gRow : col;
const int real_col = Axis == 1 ? col : gRow;

T val = (col < W) ? src[real_row * W + real_col] : (T)0;

#pragma unroll
for (int off = 1; off < SEG; off <<= 1) {
T n = (T)__shfl_down_sync(MASK, val, off);
if (lane < SEG - off)
val += n;
}

val += carry;

if (real_col < W)
dst[real_row * W + real_col] = val;

T segSum = (T)__shfl_sync(MASK, val, (T)0);
if (lane == 0)
carry = segSum;
carry = (T)__shfl_sync(MASK, carry, (T)0);
}
} else {
for (int seg = 0; seg * SEG < W; ++seg) {
const int col = seg * SEG + lane;

const int real_row = Axis == 1 ? gRow : col;
const int real_col = Axis == 1 ? col : gRow;

T val = (col < W) ? src[real_row * W + real_col] : (T)0;

#pragma unroll
for (int off = 1; off < SEG; off <<= 1) {
T n = (T)__shfl_up_sync(MASK, val, off);
if (lane >= off)
val += n;
}

val += carry;

if (real_col < W)
dst[real_row * W + real_col] = val;

T segSum = (T)__shfl_sync(MASK, val, SEG - 1);
if (lane == SEG - 1)
carry = segSum;
carry = (T)__shfl_sync(MASK, carry, SEG - 1);
}
}
}
}
};

} // namespace tl
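`CumSum2D` assigns one warp per row of the tile (`TILE_H = threads / 32` rows per block iteration, with `real_row`/`real_col` swapping roles when `Axis == 0`) and walks the row in 32-lane segments: each segment performs an inclusive warp scan with shuffles, then the segment total is broadcast and carried into the next segment, iterating from the last segment backwards when `reverse` is set. A scalar sketch of that segment-and-carry structure (illustration only; `row_cumsum_segmented` is a hypothetical helper, and the device code uses `__shfl_up_sync`/`__shfl_down_sync` instead of inner loops):

```python
def row_cumsum_segmented(row, seg=32, reverse=False):
    """Per-row inclusive scan processed in fixed-size segments with a running carry."""
    out = [0.0] * len(row)
    num_segs = (len(row) + seg - 1) // seg
    seg_order = range(num_segs - 1, -1, -1) if reverse else range(num_segs)
    carry = 0.0
    for s in seg_order:
        lo, hi = s * seg, min((s + 1) * seg, len(row))
        lane_order = range(hi - 1, lo - 1, -1) if reverse else range(lo, hi)
        acc = 0.0
        for i in lane_order:
            acc += row[i]          # inclusive scan within the segment (warp shuffle on device)
            out[i] = acc + carry   # add the carry from segments processed so far
        carry += acc               # segment total feeds the next segment
    return out
```

For example, `row_cumsum_segmented([1, 2, 3, 4, 5], seg=2, reverse=True)` yields `[15.0, 14.0, 12.0, 9.0, 5.0]`.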
96 changes: 96 additions & 0 deletions testing/python/language/test_tilelang_language_cumsum.py
@@ -0,0 +1,96 @@
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.

from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl
import torch


def cumsum_smem_test(M, N, block_M, block_N, dim=0, reverse=False, dtype="float16"):
import tilelang.language as T

@T.prim_func
def cumsum(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
A_shared = T.alloc_shared((block_M, block_N), dtype)

T.copy(A[by * block_M, bx * block_N], A_shared)
T.cumsum(src=A_shared, dim=dim, reverse=reverse)
T.copy(A_shared, B[by * block_M, bx * block_N])

return cumsum


def cumsum_fragment_test(M, N, block_M, block_N, dim=0, reverse=False, dtype="float16"):
import tilelang.language as T

@T.prim_func
def cumsum(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
A_shared = T.alloc_shared((block_M, block_N), dtype)
A_fragment = T.alloc_fragment((block_M, block_N), dtype)

T.copy(A[by * block_M, bx * block_N], A_shared)
T.copy(A_shared, A_fragment)
T.cumsum(src=A_fragment, dim=dim, reverse=reverse)
T.copy(A_fragment, B[by * block_M, bx * block_N])

return cumsum


def run_cumsum(M, N, block_M, block_N, dim=0, reverse=False, dtype="float16", scope="smem"):
if scope == "smem":
program = cumsum_smem_test(M, N, block_M, block_N, dim, reverse, dtype)
elif scope == "fragment":
program = cumsum_fragment_test(M, N, block_M, block_N, dim, reverse, dtype)
jit_kernel = tl.compile(program, out_idx=-1)
profiler = jit_kernel.get_profiler(tensor_supply_type=tl.TensorSupplyType.One)

def ref_program(A):
ref_b = torch.empty_like(A)
for i in range(M // block_M):
for j in range(N // block_N):
ref_b[i * block_M:(i + 1) * block_M,
j * block_N:(j + 1) * block_N] = A[i * block_M:(i + 1) * block_M, j *
block_N:(j + 1) * block_N].cumsum(dim=dim)
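                # All-ones inputs (TensorSupplyType.One) make flip(cumsum(x)) equal to the reverse cumsum.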
if reverse:
ref_b[i * block_M:(i + 1) * block_M, j * block_N:(j + 1) *
block_N] = ref_b[i * block_M:(i + 1) * block_M,
j * block_N:(j + 1) * block_N].flip(dims=[dim])
return ref_b

profiler.assert_allclose(ref_program)


def test_cumsum_smem():
# Test different sizes
run_cumsum(1024, 1024, 128, 128)
run_cumsum(1024, 1024, 128, 128, dim=1)
run_cumsum(1024, 1024, 128, 128, dim=1, reverse=True)

# Test different dtypes
run_cumsum(256, 256, 128, 128, dtype="float32")
run_cumsum(256, 256, 128, 128, dtype="float16")


def test_cumsum_fragment():
run_cumsum(1024, 1024, 128, 128, scope="fragment")
run_cumsum(1024, 1024, 128, 128, dim=1, scope="fragment")
run_cumsum(1024, 1024, 128, 128, dim=1, reverse=True, scope="fragment")

# Test different dtypes
run_cumsum(256, 256, 128, 128, dtype="float32", scope="fragment")
run_cumsum(256, 256, 128, 128, dtype="float16", scope="fragment")


if __name__ == "__main__":
tilelang.testing.main()
57 changes: 57 additions & 0 deletions testing/python/language/test_tilelang_language_reduce_sum.py
@@ -0,0 +1,57 @@
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.

from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl

tilelang.disable_cache()


def reduce_sum_test(M, N, dtype="float16"):
import tilelang.language as T

@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M,), dtype),
):
with T.Kernel(1) as _:
A_local = T.alloc_fragment((M, N), dtype)
B_local = T.alloc_fragment((M,), dtype)

# Copy input to local
T.copy(A, A_local)
# Perform reduce_sum operation
T.reduce_sum(A_local, B_local, dim=0)
# Copy result back
T.copy(B_local, B)

return main


def run_reduce_sum(M, N, dtype="float16"):
program = reduce_sum_test(M, N, dtype)
jit_kernel = tl.compile(program, out_idx=-1)
print(jit_kernel.get_kernel_source())
profiler = jit_kernel.get_profiler()

def ref_program(A):
return A.sum(dim=0)

profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


def test_reduce_sum():
# Test different sizes
run_reduce_sum(256, 256)
run_reduce_sum(512, 128)
run_reduce_sum(128, 512)

# Test different dtypes
run_reduce_sum(256, 256, "float32")
run_reduce_sum(256, 256, "float16")


if __name__ == "__main__":
tilelang.testing.main()
1 change: 1 addition & 0 deletions tilelang/language/__init__.py
@@ -50,6 +50,7 @@
reduce_sum, # noqa: F401
reduce_abssum, # noqa: F401
reduce_absmax, # noqa: F401
cumsum, # noqa: F401
)
from .print import print # noqa: F401
from .customize import (
32 changes: 32 additions & 0 deletions tilelang/language/reduce.py
@@ -3,6 +3,8 @@
"""The language interface for tl programs."""

from tvm import tir
from typing import Optional
from tilelang.language import copy, macro, alloc_shared


def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clear: bool):
@@ -106,3 +108,33 @@ def reduce_absmax(buffer: tir.Buffer, out: tir.Buffer, dim: int):
tir.Call: Handle to the reduction operation
"""
return reduce(buffer, out, "absmax", dim, True)


@macro
def cumsum_fragment(src: tir.Buffer, dst: tir.Buffer, dim: int, reverse: bool) -> tir.PrimExpr:
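    # Fragment path: stage through dynamic shared memory, run the shared-memory cumsum in place, then copy back.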
cumsum_smem = alloc_shared(src.shape, src.dtype, "shared.dyn")
copy(src, cumsum_smem)
tir.call_intrin(
"handle",
tir.op.Op.get("tl.cumsum"),
cumsum_smem.access_ptr("r"),
cumsum_smem.access_ptr("w"),
dim,
reverse,
)
copy(cumsum_smem, dst)


def cumsum(src: tir.Buffer, dst: Optional[tir.Buffer] = None, dim: int = 0, reverse: bool = False):
    """Compute the inclusive cumulative sum of a buffer along a dimension.

    Args:
        src (tir.Buffer): Input buffer
        dst (Optional[tir.Buffer]): Output buffer; defaults to `src` (in-place update)
        dim (int): Dimension along which to accumulate
        reverse (bool): Whether to accumulate from the end of the dimension

    Returns:
        tir.Call: Handle to the cumulative sum operation
    """
if dst is None:
dst = src
if src.scope() == "local.fragment":
return cumsum_fragment(src, dst, dim, reverse)
return tir.call_intrin(
"handle",
tir.op.Op.get("tl.cumsum"),
src.access_ptr("r"),
dst.access_ptr("w"),
dim,
reverse,
)
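Usage sketch (mirrors the shared-memory test added in this PR; the kernel name and shapes are illustrative): `T.cumsum` runs in place when `dst` is omitted, and fragment inputs take the `cumsum_fragment` path above, round-tripping through dynamic shared memory.

```python
import tilelang as tl
import tilelang.language as T

@T.prim_func
def kernel(A: T.Tensor((128, 128), "float16"), B: T.Tensor((128, 128), "float16")):
    with T.Kernel(1, threads=256) as _:
        A_shared = T.alloc_shared((128, 128), "float16")
        T.copy(A, A_shared)
        T.cumsum(src=A_shared, dim=1, reverse=True)  # in place: dst defaults to src
        T.copy(A_shared, B)

jit_kernel = tl.compile(kernel, out_idx=-1)  # same compile path as the tests above
```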
4 changes: 3 additions & 1 deletion tilelang/utils/tensor.py
@@ -295,6 +295,8 @@ def torch_assert_close(
f"({max_mismatched_ratio * 100:.2f}% allowed, but get {num_mismatched / total_elements * 100:.2f}%)."
f"{mismatch_info}"
f"\nGreatest absolute difference: {diff.max().item()}, "
f"Greatest relative difference: {(diff / (torch.abs(tensor_b) + 1e-12)).max().item()}.")
f"Greatest relative difference: {(diff / (torch.abs(tensor_b) + 1e-12)).max().item()}"
f"\nLHS: {tensor_a}"
f"\nRHS: {tensor_b}")
else:
return True