25 changes: 18 additions & 7 deletions src/op/reduce.cc
@@ -420,12 +420,23 @@ Stmt CumSumOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
ICHECK(this->dst.scope() == "shared.dyn" || this->dst.scope() == "shared");
std::stringstream ss;
auto threads = T.thread_bounds->extent;
ss << "tl::CumSum2D<" << threads << ", " << dim << ", "
<< (reverse ? "true" : "false") << ">::run";
Array<PrimExpr> args = {StringImm(ss.str()), src.access_ptr(1),
dst.access_ptr(3)};
for (int i = 0; i < src->shape.size(); i++) {
args.push_back(src->shape[i]);
Array<PrimExpr> args;
int ndim = static_cast<int>(src->shape.size());
if (ndim == 1) {
ICHECK_EQ(dim, 0) << "Cumulative sum over a 1D buffer only supports dim "
"= 0.";
ss << "tl::CumSum1D<" << threads << ", " << (reverse ? "true" : "false")
<< ">::run";
args = {StringImm(ss.str()), src.access_ptr(1), dst.access_ptr(3),
src->shape[0]};
} else if (ndim == 2) {
ss << "tl::CumSum2D<" << threads << ", " << dim << ", "
<< (reverse ? "true" : "false") << ">::run";
args = {StringImm(ss.str()), src.access_ptr(1), dst.access_ptr(3),
src->shape[0], src->shape[1]};
} else {
LOG(FATAL) << "CumSum currently supports only 1D or 2D buffers, got "
<< ndim << "D.";
}
return Evaluate(Call(dst->dtype, builtin::call_extern(), args));
} else {
@@ -446,4 +457,4 @@ TIR_REGISTER_TL_OP(CumSumOp, cumsum)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
} // namespace tl
} // namespace tvm
} // namespace tvm
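The lowering above assembles the extern call name from the thread extent, the scan axis, and the reverse flag, then passes the buffer extents as trailing arguments. As a rough sketch (not actual codegen output; the pointer names are made up, and the real arguments come from src.access_ptr(1) and dst.access_ptr(3)), the emitted call for a 1D shared-memory buffer of length N with 128 threads and a forward scan would look like:

// Illustrative only: src_ptr/dst_ptr stand in for the lowered access pointers.
tl::CumSum1D<128, false>::run(src_ptr, dst_ptr, N);
// The 2D path is unchanged apart from passing both extents:
tl::CumSum2D<128, /*Axis=*/1, false>::run(src_ptr, dst_ptr, rows, cols);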
68 changes: 68 additions & 0 deletions src/tl_templates/cuda/reduce.h
@@ -68,6 +68,74 @@ struct AllReduce {
}
};

template <int threads, bool reverse = false> struct CumSum1D {
static_assert(threads == 1024 or threads == 512 or threads == 256 or
threads == 128 or threads == 64 or threads == 32);
template <typename T, int SEG = 32>
static TL_DEVICE void run(const T *__restrict__ src, T *__restrict__ dst,
int N) {
if (N <= 0)
return;

constexpr unsigned MASK = 0xffffffff;
const int tid = threadIdx.x;
const int lane = tid % SEG;

if (tid >= SEG)
return;

T carry = (T)0;

if (reverse) {
const int num_segments = (N + SEG - 1) / SEG;
for (int seg = num_segments - 1; seg >= 0; --seg) {
const int idx = seg * SEG + lane;
T val = (idx < N) ? src[idx] : (T)0;

#pragma unroll
for (int off = 1; off < SEG; off <<= 1) {
T n = (T)__shfl_down_sync(MASK, val, off);
if (lane < SEG - off)
val += n;
}

val += carry;

if (idx < N)
dst[idx] = val;

T segSum = (T)__shfl_sync(MASK, val, 0);
if (lane == 0)
carry = segSum;
carry = (T)__shfl_sync(MASK, carry, 0);
}
} else {
const int num_segments = (N + SEG - 1) / SEG;
for (int seg = 0; seg < num_segments; ++seg) {
const int idx = seg * SEG + lane;
T val = (idx < N) ? src[idx] : (T)0;

#pragma unroll
for (int off = 1; off < SEG; off <<= 1) {
T n = (T)__shfl_up_sync(MASK, val, off);
if (lane >= off)
val += n;
}

val += carry;

if (idx < N)
dst[idx] = val;

T segSum = (T)__shfl_sync(MASK, val, SEG - 1);
if (lane == SEG - 1)
carry = segSum;
carry = (T)__shfl_sync(MASK, carry, SEG - 1);
}
}
}
};

template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
static_assert(threads == 1024 or threads == 512 or threads == 256 or
threads == 128 or threads == 64 or threads == 32);
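For context, a minimal sketch of how CumSum1D could be exercised from a hand-written kernel (illustrative only: the kernel name, launch configuration, and use of global-memory pointers are assumptions for the demo; in TileLang the callers are generated by the lowering in src/op/reduce.cc and operate on shared memory):

#include "reduce.h"   // assumes the tl_templates/cuda headers are on the include path
using namespace tl;   // CumSum1D is referenced as tl::CumSum1D by the generated calls

// 128 threads per block; only the first warp performs the segmented scan,
// the remaining threads return early inside run().
__global__ void cumsum1d_demo(const float *src, float *dst, int N) {
  CumSum1D<128, /*reverse=*/false>::run(src, dst, N);
}

// Host-side launch (one block suffices, since a single warp walks all segments):
// cumsum1d_demo<<<1, 128>>>(d_src, d_dst, N);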
79 changes: 79 additions & 0 deletions testing/python/language/test_tilelang_language_cumsum.py
@@ -71,6 +71,75 @@ def ref_program(A):
torch.testing.assert_close(tilelang_res, ref_res, atol=1e-3, rtol=1e-3)


def cumsum_smem_test_1d(N, block_N, reverse=False, dtype="float32"):
import tilelang.language as T

@T.prim_func
def cumsum(
A: T.Tensor((N,), dtype),
B: T.Tensor((N,), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx:
A_shared = T.alloc_shared((block_N,), dtype)

T.copy(A[bx * block_N], A_shared)
T.cumsum(src=A_shared, dim=0, reverse=reverse)
T.copy(A_shared, B[bx * block_N])

return cumsum


def cumsum_fragment_test_1d(N, block_N, reverse=False, dtype="float32"):
import tilelang.language as T

@T.prim_func
def cumsum(
A: T.Tensor((N,), dtype),
B: T.Tensor((N,), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx:
A_shared = T.alloc_shared((block_N,), dtype)
A_fragment = T.alloc_fragment((block_N,), dtype)

T.copy(A[bx * block_N], A_shared)
T.copy(A_shared, A_fragment)
T.cumsum(src=A_fragment, dim=0, reverse=reverse)
T.copy(A_fragment, B[bx * block_N])

return cumsum


def run_cumsum_1d(N, block_N, reverse=False, dtype="float32", scope="smem"):
if scope == "smem":
program = cumsum_smem_test_1d(N, block_N, reverse, dtype)
elif scope == "fragment":
program = cumsum_fragment_test_1d(N, block_N, reverse, dtype)
else:
raise ValueError(f"Unknown scope {scope}")

jit_kernel = tl.compile(program, out_idx=-1)
A = torch.randn(N, dtype=getattr(torch, dtype)).cuda()

def ref_program(A):
ref_b = torch.empty_like(A)
num_blocks = (N + block_N - 1) // block_N
for j in range(num_blocks):
start = j * block_N
end = min(start + block_N, N)
chunk = A[start:end]
if reverse:
chunk = torch.flip(chunk, dims=[0])
chunk = chunk.cumsum(dim=0)
if reverse:
chunk = torch.flip(chunk, dims=[0])
ref_b[start:end] = chunk
return ref_b

tilelang_res = jit_kernel(A)
ref_res = ref_program(A)
torch.testing.assert_close(tilelang_res, ref_res, atol=1e-3, rtol=1e-3)


def test_cumsum_smem():
# Test different sizes
run_cumsum(1024, 1024, 128, 128)
@@ -92,5 +161,15 @@ def test_cumsum_fragment():
run_cumsum(256, 256, 128, 128, dtype="float32", scope="fragment")


def test_cumsum_smem_1d():
run_cumsum_1d(1024, 128)
run_cumsum_1d(1024, 128, reverse=True)


def test_cumsum_fragment_1d():
run_cumsum_1d(1024, 128, scope="fragment")
run_cumsum_1d(1024, 128, reverse=True, scope="fragment")


if __name__ == "__main__":
tilelang.testing.main()
23 changes: 23 additions & 0 deletions tilelang/language/reduce.py
@@ -160,6 +160,29 @@ def cumsum(src: tir.Buffer, dst: Optional[tir.Buffer] = None, dim: int = 0, reve

Negative `dim` indices are normalized (Python-style). If `dst` is None, the operation is performed in-place into `src`. Raises ValueError when `dim` is out of bounds for `src.shape`. When `src.scope() == "local.fragment"`, this delegates to `cumsum_fragment`; otherwise it emits the `tl.cumsum` intrinsic.

Examples:
A 1D inclusive scan that writes the result into a separate shared-memory buffer:

>>> import tilelang.language as T
>>> @T.prim_func
... def kernel(A: T.Tensor((128,), "float32"), B: T.Tensor((128,), "float32")):
... with T.Kernel(1, threads=128):
... A_shared = T.alloc_shared((128,), "float32")
... T.copy(A, A_shared)
... T.cumsum(src=A_shared, dst=A_shared, dim=0)
... T.copy(A_shared, B)

A 2D prefix sum along the last dimension with reverse accumulation:

>>> import tilelang.language as T
>>> @T.prim_func
... def kernel2d(A: T.Tensor((64, 64), "float16"), B: T.Tensor((64, 64), "float16")):
... with T.Kernel(1, 1, threads=256):
... tile = T.alloc_shared((64, 64), "float16")
... T.copy(A, tile)
... T.cumsum(src=tile, dim=1, reverse=True)
... T.copy(tile, B)

Returns:
tir.Call: A handle to the emitted cumulative-sum operation.
"""