65 changes: 55 additions & 10 deletions src/op/reduce.cc
@@ -15,6 +15,7 @@

#include "../layout/utils.h"
#include "../transform/loop_partition.h"
#include "tir/transforms/ir_utils.h"

namespace tvm {
namespace tl {
@@ -124,10 +125,33 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {

Array<Stmt> stmts;

bool require_init = this->clear;
// sum-style reductions always need an initialized accumulator
if (this->type == ReduceType::kSum) {
require_init = true;
} else if (this->type == ReduceType::kAbsSum) {
require_init = true;
}

Buffer clear_buffer = dst_buffer;
bool need_duplicate = false;
if (this->type == ReduceType::kSum && !this->clear) {
need_duplicate = true;
} else if (this->type == ReduceType::kAbsSum && !this->clear) {
need_duplicate = true;
}

if (need_duplicate) {
// Create a scratch buffer with the same shape and dtype as dst_buffer
clear_buffer = decl_buffer(dst_buffer->shape, dst_buffer->dtype,
dst_buffer->name + "_clear",
GetPtrStorageScope(dst_buffer->data));
}

// make reduce-init stmt
- if (this->clear)
+ if (require_init)
stmts.push_back(
- BufferStore(dst_buffer, this->MakeInitValue(), dst_indices));
+ BufferStore(clear_buffer, this->MakeInitValue(), dst_indices));

// make thread-local reduce
Array<PrimExpr> src_indice_compressed;
@@ -141,8 +165,8 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
src_var_compressed.push_back(var);
}
Stmt reduce_local = BufferStore(
- dst_buffer,
- this->MakeReduce(BufferLoad(dst_buffer, dst_indices),
+ clear_buffer,
+ this->MakeReduce(BufferLoad(clear_buffer, dst_indices),
BufferLoad(src_buffer, src_indice_compressed)),
dst_indices);
for (int i = src_layout->OutputDim() - 1; i >= 0; i--) {
@@ -179,20 +203,37 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
<< reducing_threads << ", " << (*scale) << ">::run";
}
Array<PrimExpr> thread_reduce_args = {
- StringImm(ss.str()), BufferLoad(dst_buffer, dst_indices)};
+ StringImm(ss.str()), BufferLoad(clear_buffer, dst_indices)};
if (reducing_threads >= 32) {
PrimExpr workspace = T.AddWorkspace(
- *as_const_int(T.thread_bounds->extent), dst_buffer->dtype);
+ *as_const_int(T.thread_bounds->extent), clear_buffer->dtype);
thread_reduce_args.push_back(workspace);
}
auto call =
- Call(dst_buffer->dtype, builtin::call_extern(), thread_reduce_args);
- stmts.push_back(BufferStore(dst_buffer, call, dst_indices));
+ Call(clear_buffer->dtype, builtin::call_extern(), thread_reduce_args);
+ stmts.push_back(BufferStore(clear_buffer, call, dst_indices));
}
}
- Stmt reduce_interthread =
- BufferStore(dst_buffer, BufferLoad(dst_buffer, dst_indices), dst_indices);
+ Stmt reduce_interthread = BufferStore(
+ clear_buffer, BufferLoad(clear_buffer, dst_indices), dst_indices);

// accumulate clear_buffer back into dst_buffer
if (need_duplicate) {
// for sum-style reductions, add the temporary result onto the existing dst values
if (this->type == ReduceType::kSum) {
stmts.push_back(BufferStore(dst_buffer,
Add(BufferLoad(dst_buffer, dst_indices),
BufferLoad(clear_buffer, dst_indices)),
dst_indices));
} else if (this->type == ReduceType::kAbsSum) {
stmts.push_back(BufferStore(dst_buffer,
Add(BufferLoad(dst_buffer, dst_indices),
BufferLoad(clear_buffer, dst_indices)),
dst_indices));
} else {
ICHECK(false) << "Unsupported reduce type: " << (int)this->type;
}
}
// make the outer spatial loop
Stmt body = stmts.size() > 1 ? SeqStmt(stmts) : stmts[0];
for (int i = dst_layout->InputDim() - 1; i >= 0; i--) {
@@ -201,6 +242,10 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
}

body = PartitionLoop(Downcast<For>(body), T.thread_var, analyzer, dst_layout);
if (need_duplicate) {
body = Allocate(clear_buffer->data, clear_buffer->dtype,
clear_buffer->shape, const_true(), body);
}
return body;
}

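A quick sketch of why the accumulate path above needs the scratch buffer (a toy model, not the PR's code; `k`, `prev`, and `partials` are illustrative). If each reducing thread folded the existing output value into its partial before the cross-thread reduction, that value would be counted once per thread; reducing into a freshly initialized temporary and adding it back counts it exactly once.

```python
# Toy model of the double-counting hazard avoided by clear_buffer.
k = 32                     # threads cooperating in the warp reduction
prev = 5.0                 # value already sitting in the output buffer
partials = [1.0] * k       # each thread's local partial sum

wrong = sum(p + prev for p in partials)  # prev counted k times -> 192.0
right = prev + sum(partials)             # prev counted once    -> 37.0
print(wrong, right)
```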
5 changes: 2 additions & 3 deletions src/transform/pipeline_planning.cc
@@ -389,7 +389,6 @@ class PipelinePlanner : public StmtExprMutator {
// Handle trailing unassigned copy stages:
// These are typically final copy operations needing post-main-stage
// insertion

auto &head_pinfo = pipeline_stage_infos.at(0);
int unassigned_order_elem = -1;

@@ -422,7 +421,7 @@
int copy_order_min = pipeline_stage_infos.size();
int non_copy_order_max = 0;
for (auto &pinfo : pipeline_stage_infos) {
- if (pinfo.copy_stage) {
+ if (pinfo.copy_stage || pinfo.prepare_for_condition) {
copy_stage_cnt++;
copy_order_min = std::min(copy_order_min, pinfo.order);
} else {
@@ -437,7 +436,7 @@
for (auto &pinfo : pipeline_stage_infos) { // move copy to the beginning
pinfo.order =
(pinfo.order + copy_stage_at_end) % pipeline_stage_infos.size();
- if (!pinfo.copy_stage)
+ if (!pinfo.copy_stage && !pinfo.prepare_for_condition)
pinfo.stage--;
}
}
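A minimal sketch of the order rotation this hunk relies on (the 4-stage layout and `copy_stage_at_end = 1` are assumed for illustration): adding `copy_stage_at_end` modulo the stage count moves trailing copy-like stages, which now include `prepare_for_condition` stages, to the front of the issue order.

```python
# Stage orders before rotation; the last stage is a trailing copy.
orders = [0, 1, 2, 3]
copy_stage_at_end = 1      # number of trailing copy-like stages (assumed)
n = len(orders)

rotated = [(o + copy_stage_at_end) % n for o in orders]
print(rotated)             # [1, 2, 3, 0]: the old order-3 copy now runs first
```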
41 changes: 41 additions & 0 deletions testing/python/language/test_tilelang_language_reduce_max.py
@@ -50,5 +50,46 @@ def test_reduce_max():
run_reduce_max(256, 256, "float16")


def reduce_max_test_clear(M, N, dtype="float16"):
import tilelang.language as T

@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M,), dtype),
):
with T.Kernel(1, threads=32) as _:
A_local = T.alloc_fragment((M, N), dtype)
B_local = T.alloc_fragment((M,), dtype)

T.copy(A, A_local)
T.fill(B_local, -T.infinity(dtype))
T.reduce_max(A_local, B_local, dim=1, clear=False)
T.copy(B_local, B)

return main


def run_reduce_max_clear(M, N, dtype="float16"):
program = reduce_max_test_clear(M, N, dtype)
jit_kernel = tl.compile(program, out_idx=-1)
print(jit_kernel.get_kernel_source())

def ref_program(A):
return A.max(dim=1).values

import torch
dummy_A = torch.randn((M, N), dtype=getattr(torch, dtype)).cuda()
ref_out = ref_program(dummy_A)
tl_out = jit_kernel(dummy_A)
print(tl_out)
print(ref_out)
torch.testing.assert_close(tl_out, ref_out, atol=1e-2, rtol=1e-2)


def test_reduce_max_clear():
run_reduce_max_clear(256, 256, "float16")


if __name__ == "__main__":
tilelang.testing.main()
105 changes: 105 additions & 0 deletions testing/python/language/test_tilelang_language_reduce_sum.py
@@ -0,0 +1,105 @@
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.

from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl

tilelang.testing.set_random_seed()


def reduce_sum_test(M, N, dtype="float16"):
import tilelang.language as T

@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M,), dtype),
):
with T.Kernel(1) as _:
A_local = T.alloc_fragment((M, N), dtype)
B_local = T.alloc_fragment((M,), dtype)

# Copy input to local
T.copy(A, A_local)
# Perform reduce_sum operation
T.reduce_sum(A_local, B_local, dim=1)
# Copy result back
T.copy(B_local, B)

return main


def run_reduce_sum(M, N, dtype="float16"):
program = reduce_sum_test(M, N, dtype)
jit_kernel = tl.compile(program, out_idx=-1)
profiler = jit_kernel.get_profiler()

def ref_program(A):
return A.sum(dim=1)

profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


def test_reduce_sum():
# Test different sizes
run_reduce_sum(256, 256)
run_reduce_sum(512, 128)
run_reduce_sum(128, 512)

# Test different dtypes
run_reduce_sum(256, 256, "float32")
run_reduce_sum(256, 256, "float16")


def reduce_sum_test_clear(M, N, dtype="float16"):
import tilelang.language as T

@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M,), dtype),
):
with T.Kernel(1, threads=32) as _:
A_local = T.alloc_fragment((M, N), dtype)
B_local = T.alloc_fragment((M,), dtype)

T.copy(A, A_local)
T.fill(B_local, 1)
T.reduce_sum(A_local, B_local, dim=1, clear=False)
T.copy(B_local, B)

return main


def run_reduce_sum_clear(M, N, dtype="float16"):
program = reduce_sum_test_clear(M, N, dtype)
jit_kernel = tl.compile(
program,
out_idx=-1,
pass_configs={
"tl.disable_tma_lower": True,
"tl.disable_warp_specialized": True,
})
print(jit_kernel.get_kernel_source())

def ref_program(A):
return A.sum(dim=1) + 1

import torch
dummy_A = torch.randn((M, N), dtype=getattr(torch, dtype)).cuda()
ref_out = ref_program(dummy_A)
tl_out = jit_kernel(dummy_A)
print(tl_out)
print(ref_out)
torch.testing.assert_close(tl_out, ref_out, atol=1e-2, rtol=1e-2)


def test_reduce_sum_clear():
run_reduce_sum_clear(256, 256, "float32")
run_reduce_sum_clear(512, 128, "float32")
run_reduce_sum_clear(128, 512, "float32")


if __name__ == "__main__":
tilelang.testing.main()
@@ -13,12 +13,12 @@ class BaseCompare(tilelang.testing.CompareBeforeAfter):
class TestAnnotateThreadExtent(BaseCompare):
"""Annotation inserted at the "thread_extent" attribute"""

- def before(A: T.Buffer(16, "float32")):
+ def before(A: T.Tensor(16, "float32")):
T.func_attr({"target": T.target("cuda", host="llvm")})
i = T.launch_thread("threadIdx.x", 16)
A[i] = 0.0

- def expected(A: T.Buffer(16, "float32")):
+ def expected(A: T.Tensor(16, "float32")):
T.func_attr({"target": T.target("cuda", host="llvm")})
T.attr(T.target("cuda"), "target", 0)
i = T.launch_thread("threadIdx.x", 16)
@@ -28,12 +28,12 @@ def expected(A: T.Buffer(16, "float32")):
class TestAnnotateDeviceScope(BaseCompare):
"""Annotation inserted at the "device_scope" attribute"""

- def before(A: T.Buffer(1, "float32")):
+ def before(A: T.Tensor(1, "float32")):
T.func_attr({"target": T.target("cuda", host="llvm")})
T.attr(0, "device_scope", 0)
A[0] = 0.0

- def expected(A: T.Buffer(1, "float32")):
+ def expected(A: T.Tensor(1, "float32")):
T.func_attr({"target": T.target("cuda", host="llvm")})
T.attr(T.target("cuda"), "target", 0)
T.attr(0, "device_scope", 0)
18 changes: 14 additions & 4 deletions tilelang/language/reduce.py
@@ -68,18 +68,28 @@ def reduce_min(buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True
return reduce(buffer, out, "min", dim, clear)


- def reduce_sum(buffer: tir.Buffer, out: tir.Buffer, dim: int):
+ def reduce_sum(buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True):
"""Perform reduce sum on input buffer, store the result to output buffer.

Args:
buffer (tir.Buffer): The input buffer
out (tir.Buffer): The output buffer
dim (int): The dimension to perform reduce on
clear (bool, optional): If True, the output buffer is cleared before the reduction.
If False, results are accumulated onto the existing values.
Defaults to True.
Note: When clear=False, reduce_sum does not accumulate directly into the output
buffer, because during the warp reduction the existing output value would otherwise
be added once per participating thread. The lowering for clear=False therefore:
1. creates a temp buffer with the same shape and dtype as out
2. initializes the temp buffer with the reduction's init value (zero)
3. reduces into the temp buffer
4. adds the temp buffer back onto out

Returns:
tir.Call: Handle to the reduction operation
"""
return reduce(buffer, out, "sum", dim, True)
return reduce(buffer, out, "sum", dim, clear)


def reduce_abssum(buffer: tir.Buffer, out: tir.Buffer, dim: int):
@@ -96,7 +106,7 @@ def reduce_abssum(buffer: tir.Buffer, out: tir.Buffer, dim: int):
return reduce(buffer, out, "abssum", dim, True)


- def reduce_absmax(buffer: tir.Buffer, out: tir.Buffer, dim: int):
+ def reduce_absmax(buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True):
"""Perform reduce absolute max on input buffer, store the result to output buffer.

Args:
@@ -107,7 +117,7 @@ def reduce_absmax(buffer: tir.Buffer, out: tir.Buffer, dim: int):
Returns:
tir.Call: Handle to the reduction operation
"""
return reduce(buffer, out, "absmax", dim, True)
return reduce(buffer, out, "absmax", dim, clear)


@macro