91 changes: 84 additions & 7 deletions src/op/reduce.cc
@@ -16,6 +16,7 @@
#include "../transform/loop_partition.h"
#include "region.h"
#include "tir/transforms/ir_utils.h"
#include "tvm/tir/stmt.h"

namespace tvm {
namespace tl {
@@ -57,12 +58,65 @@ static BufferRegion NormalizeToBufferRegion(const PrimExpr &arg,
RegionOp region(call->args, vmap);
return BufferRegion(region->GetBuffer(), region->GetRanges());
}
// builtin.tvm_access_ptr(...) — map var to Buffer and take full region
if (call->op.same_as(builtin::tvm_access_ptr())) {
Var var = Downcast<Var>(call->args[1]);
Buffer buf = vmap[var];
Array<Range> ranges;
for (PrimExpr extent : buf->shape) {
ranges.push_back(Range(IntImm(extent->dtype, 0), extent));
}
return BufferRegion(buf, ranges);
}
}

LOG(FATAL) << "Unsupported argument for BufferRegion in reduce: " << arg;
throw; // Unreachable
}

// Build a tvm_access_ptr(handle) to the start of the 2D tile within a
// BufferRegion. Offset is computed from all but the last two dimensions; extent
// is the product of the last two extents. rw_mask: 1=read, 2=write,
// 3=readwrite.
static PrimExpr MakeAccessPtrFromRegion(const BufferRegion &region,
int rw_mask) {
Buffer buf = region->buffer;
int ndim = static_cast<int>(buf->shape.size());
ICHECK(ndim == 1 || ndim == 2) << "Cumsum expects buffers with 1 or 2 dims";

PrimExpr offset, extent;
if (ndim == 1) {
// Simple 1D region: offset and extent come from the single axis.
auto axis = region->region[0];
offset = axis->min;
extent = axis->extent;
} else {
// Compute row-major strides for ndim >= 2
std::vector<PrimExpr> strides(ndim);
PrimExpr one = make_const(buf->shape[0].dtype(), 1);
PrimExpr cur = one;
for (int i = ndim - 1; i >= 0; --i) {
strides[i] = cur;
cur = cur * buf->shape[i];
}
// Offset: sum_{i in [0..ndim-3]} min_i * stride_i
offset = make_const(buf->shape[0].dtype(), 0);
for (int i = 0; i < ndim - 2; ++i) {
offset = offset + region->region[i]->min * strides[i];
}

// Extent: last two extents product (elements)
extent =
region->region[ndim - 2]->extent * region->region[ndim - 1]->extent;
}

// ptype and return handle
PrimExpr ptype = tir::TypeAnnotation(buf->dtype);
Array<PrimExpr> acc_args{ptype, buf->data, offset, extent,
IntImm(DataType::Int(32), rw_mask)};
return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args);
}
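
Illustrative aside (not part of the diff): the stride, offset, and extent arithmetic above, mirrored in plain Python with concrete numbers. The helper name access_ptr_params and the shapes below are hypothetical and exist only to make the row-major computation tangible.

def access_ptr_params(shape, region):
    # region is a list of (min, extent) pairs, one per buffer dimension.
    ndim = len(shape)
    if ndim == 1:
        return region[0]  # 1D: offset and extent come straight from the single axis
    # Row-major strides: strides[ndim-1] = 1, strides[i] = strides[i+1] * shape[i+1]
    strides = [1] * ndim
    for i in range(ndim - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    # Offset sums min_i * stride_i over all but the last two axes;
    # extent is the element count of the trailing 2D tile.
    offset = sum(region[i][0] * strides[i] for i in range(ndim - 2))
    extent = region[ndim - 2][1] * region[ndim - 1][1]
    return offset, extent

print(access_ptr_params([1024], [(16, 128)]))             # (16, 128)
print(access_ptr_params([64, 128], [(0, 64), (0, 128)]))  # (0, 8192): no axis feeds the offset in 2D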

ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<ReduceOpNode> node = tvm::ffi::make_object<ReduceOpNode>();
// Accept BufferRegion/BufferLoad/tl.region for src/dst
@@ -231,6 +285,7 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
auto dst_scope = this->dst.scope();

if (src_scope == "local.fragment" && dst_scope == "local.fragment") {

Buffer src_buffer = get_buffer(this->src);
Buffer dst_buffer = get_buffer(this->dst);
Fragment src_layout = T.layout_map[this->src].as<Fragment>().value();
@@ -518,6 +573,16 @@ TIR_REGISTER_TL_OP(ReduceOp, reduce)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

// Normalize "Buffer" to BufferRegion. Use the shape of the buffer as the
// ranges.
static BufferRegion ConvertBufferToBufferRegion(const Buffer &buf) {
Array<Range> ranges;
for (PrimExpr extent : buf->shape) {
ranges.push_back(Range(IntImm(extent->dtype, 0), extent));
}
return BufferRegion(buf, ranges);
}

CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) {
/// CumSum constructor arguments:
/// - src: input buffer
@@ -526,11 +591,19 @@ CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) {
/// - reverse: whether to cumsum in reverse order
CHECK_EQ(args.size(), 4);
ObjectPtr<CumSumOpNode> node = tvm::ffi::make_object<CumSumOpNode>();
node->src = vmap[GetVarFromAccessPtr(args[0])];
node->dst = vmap[GetVarFromAccessPtr(args[1])];
// node->src = vmap[GetVarFromAccessPtr(args[0])];
// node->dst = vmap[GetVarFromAccessPtr(args[1])];
node->srcRegion_ = NormalizeToBufferRegion(args[0], vmap);
node->dstRegion_ = NormalizeToBufferRegion(args[1], vmap);
node->src = node->srcRegion_->buffer;
node->dst = node->dstRegion_->buffer;
node->dim = args[2].as<IntImm>().value()->value;
node->reverse = args[3].as<Bool>().value();
CHECK_LT(node->dim, static_cast<int>(node->src->shape.size()));
CHECK_LT(node->dim, static_cast<int>(node->src->shape.size()))
<< "The dim of cumsum should be less than the number of dimensions. Got "
"dim="
<< node->dim << ", but src has " << node->src->shape.size() << " dims.";

data_ = std::move(node);
}

@@ -546,18 +619,22 @@ Stmt CumSumOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
auto threads = T.thread_bounds->extent;
Array<PrimExpr> args;
int ndim = static_cast<int>(src->shape.size());

// Build access pointers from regions locally
PrimExpr srcPtr = MakeAccessPtrFromRegion(srcRegion_, 1);
PrimExpr dstPtr = MakeAccessPtrFromRegion(dstRegion_, 2);

if (ndim == 1) {
ICHECK_EQ(dim, 0) << "Cumulative sum over a 1D buffer only supports dim "
"= 0.";
ss << "tl::CumSum1D<" << threads << ", " << (reverse ? "true" : "false")
<< ">::run";
args = {StringImm(ss.str()), src.access_ptr(1), dst.access_ptr(3),
src->shape[0]};
args = {StringImm(ss.str()), srcPtr, dstPtr, src->shape[0]};
} else if (ndim == 2) {
ss << "tl::CumSum2D<" << threads << ", " << dim << ", "
<< (reverse ? "true" : "false") << ">::run";
args = {StringImm(ss.str()), src.access_ptr(1), dst.access_ptr(3),
src->shape[0], src->shape[1]};
args = {StringImm(ss.str()), srcPtr, dstPtr, src->shape[0],
src->shape[1]};
} else {
LOG(FATAL) << "CumSum currently supports only 1D or 2D buffers, got "
<< ndim << "D.";
8 changes: 6 additions & 2 deletions src/op/reduce.h
@@ -133,8 +133,10 @@ class ReduceOp : public TileOperator
class CumSumOpNode : public TileOperatorNode {
public:
tir::Buffer src, dst; ///< Source and destination buffers
int dim; ///< Dimension along which to compute cumulative sum
bool reverse; ///< Whether to compute in reverse order
// Optional: keep the original regions used to construct this op
BufferRegion srcRegion_, dstRegion_;
int dim; ///< Dimension along which to compute cumulative sum
bool reverse; ///< Whether to compute in reverse order
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.CumSumOp", CumSumOpNode,
TileOperatorNode);

@@ -143,6 +145,8 @@ class CumSumOpNode : public TileOperatorNode {
refl::ObjectDef<CumSumOpNode>()
.def_ro("src", &CumSumOpNode::src)
.def_ro("dst", &CumSumOpNode::dst)
.def_ro("srcRegion", &CumSumOpNode::srcRegion_)
.def_ro("dstRegion", &CumSumOpNode::dstRegion_)
.def_ro("dim", &CumSumOpNode::dim)
.def_ro("reverse", &CumSumOpNode::reverse);
}
33 changes: 33 additions & 0 deletions testing/python/issue/test_tilelang_issue_1001.py
@@ -0,0 +1,33 @@
import torch
import tilelang
import tilelang.testing
from tilelang import language as T


@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
},)
def _cumsum_view_infer_layout(hidden):
num_tokens = T.dynamic('num_tokens')

@T.prim_func
def buggy_kernel(x: T.Tensor[(num_tokens, hidden), 'float']):
with T.Kernel(num_tokens, threads=128) as pid:
smem = T.alloc_shared((hidden,), dtype='float')
T.copy(x[pid, :], smem)
T.cumsum(T.view(smem, (1, hidden)), dim=1)

return buggy_kernel


def test_cumsum_view_infer_layout():
hidden = 128
x = torch.randn(1, hidden, device='cuda', dtype=torch.float)
kernel = _cumsum_view_infer_layout(hidden)
kernel(x)


if __name__ == '__main__':
tilelang.testing.main()
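
For reference (not part of the PR): per its file and function names, the test above guards against the layout-inference failure from issue 1001, so it only needs compilation and the kernel launch to succeed; the cumulative sum stays in shared memory and is never copied back out. The per-row semantics the kernel is meant to compute, written in plain PyTorch:

import torch

def reference_cumsum(x: torch.Tensor) -> torch.Tensor:
    # Inclusive cumulative sum along the hidden dimension, one row per program id.
    return torch.cumsum(x, dim=-1)

assert torch.allclose(reference_cumsum(torch.ones(1, 4)),
                      torch.tensor([1.0, 2.0, 3.0, 4.0]))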
@@ -260,7 +260,7 @@ def test_atomic_addx2():
run_atomic_addx2(32, 64, 8, 16)


@tilelang.jit(debug_root_path="./testing/python/language")
@tilelang.jit
def atomic_different_memory_orders_program(M, N, block_M, block_N, dtype="float"):

@T.prim_func
1 change: 1 addition & 0 deletions tilelang/analysis/__init__.py
@@ -1,3 +1,4 @@
"""Tilelang IR analysis & visitors."""

from .ast_printer import ASTPrinter # noqa: F401
from .nested_loop_checker import NestedLoopChecker # noqa: F401
23 changes: 23 additions & 0 deletions tilelang/analysis/ast_printer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from tvm import tir
from tvm.tir import PrimFunc
from tvm.tir.transform import prim_func_pass
from tvm.tir.stmt_functor import ir_transform


def ASTPrinter():
"""
Print the AST of a given tilelang module for debugging.
"""

def pre_visit(statement: tir.Stmt) -> None:
"""
Pre-order visitor to print all visited statements.
"""

print(f"Visiting statement: {type(statement)}")

def pass_fn(func: PrimFunc, mod, ctx) -> PrimFunc:
new_body = ir_transform(func.body, pre_visit, None)
return func.with_body(new_body)

return prim_func_pass(pass_fn, opt_level=0)
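
A minimal usage sketch, assuming mod is an already-built tvm.IRModule (for example, the module handed to PreLowerSemanticCheck in the next file). ASTPrinter returns an opt-level-0 prim_func pass, so it is applied by calling it on a module; the dump_ast wrapper below is hypothetical and not part of this PR.

from tvm import IRModule
from tilelang.analysis import ASTPrinter

def dump_ast(mod: IRModule) -> IRModule:
    # Applying the pass prints "Visiting statement: <type>" for every statement
    # in each PrimFunc, in pre-order, and returns an equivalent module.
    return ASTPrinter()(mod)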
3 changes: 3 additions & 0 deletions tilelang/engine/phase.py
@@ -74,6 +74,9 @@ def PreLowerSemanticCheck(mod: IRModule) -> None:
Note: This is a validation-only pipeline of passes and does not modify or return the module.
"""

# Debug
# tilelang.analysis.ASTPrinter()(mod)

# Check if there are any invalid nested loops.
tilelang.analysis.NestedLoopChecker()(mod)

8 changes: 4 additions & 4 deletions tilelang/language/reduce.py
@@ -246,8 +246,8 @@ def cumsum_fragment(src: tir.Buffer, dst: tir.Buffer, dim: int, reverse: bool) -
tir.call_intrin(
"handle",
tir.op.Op.get("tl.cumsum"),
cumsum_smem.access_ptr("r"),
cumsum_smem.access_ptr("w"),
buffer_to_tile_region(cumsum_smem, "r"),
buffer_to_tile_region(cumsum_smem, "w"),
dim,
reverse,
)
@@ -300,8 +300,8 @@ def cumsum(src: tir.Buffer, dst: tir.Buffer | None = None, dim: int = 0, reverse
return tir.call_intrin(
"handle",
tir.op.Op.get("tl.cumsum"),
src.access_ptr("r"),
dst.access_ptr("w"),
buffer_to_tile_region(src, "r"),
buffer_to_tile_region(dst, "w"),
dim,
reverse,
)