Skip to content

Commit 8f58e93

Browse files
committed
remove rubbish
1 parent 887638e commit 8f58e93

File tree

4 files changed

+29
-23
lines changed

4 files changed

+29
-23
lines changed

src/op/reduce.cc

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -82,27 +82,34 @@ static PrimExpr MakeAccessPtrFromRegion(const BufferRegion &region,
8282
int rw_mask) {
8383
Buffer buf = region->buffer;
8484
int ndim = static_cast<int>(buf->shape.size());
85-
ICHECK(ndim >= 2) << "GEMM expects buffers with at least 2 dims";
86-
87-
// Compute row-major strides
88-
std::vector<PrimExpr> strides(ndim);
89-
PrimExpr one = make_const(buf->shape[0].dtype(), 1);
90-
PrimExpr cur = one;
91-
for (int i = ndim - 1; i >= 0; --i) {
92-
strides[i] = cur;
93-
cur = cur * buf->shape[i];
94-
}
85+
ICHECK(ndim == 1 || ndim == 2) << "Cumsum expects buffers with 1 or 2 dims";
86+
87+
PrimExpr offset, extent;
88+
if (ndim == 1) {
89+
// Simple 1D region: offset and extent come from the single axis.
90+
auto axis = region->region[0];
91+
offset = axis->min;
92+
extent = axis->extent;
93+
} else {
94+
// Compute row-major strides for ndim >= 2
95+
std::vector<PrimExpr> strides(ndim);
96+
PrimExpr one = make_const(buf->shape[0].dtype(), 1);
97+
PrimExpr cur = one;
98+
for (int i = ndim - 1; i >= 0; --i) {
99+
strides[i] = cur;
100+
cur = cur * buf->shape[i];
101+
}
102+
// Offset: sum_{i in [0..ndim-3]} min_i * stride_i
103+
offset = make_const(buf->shape[0].dtype(), 0);
104+
for (int i = 0; i < ndim - 2; ++i) {
105+
offset = offset + region->region[i]->min * strides[i];
106+
}
95107

96-
// Offset: sum_{i in [0..ndim-3]} min_i * stride_i
97-
PrimExpr offset = make_const(buf->shape[0].dtype(), 0);
98-
for (int i = 0; i < ndim - 2; ++i) {
99-
offset = offset + region->region[i]->min * strides[i];
108+
// Extent: last two extents product (elements)
109+
extent =
110+
region->region[ndim - 2]->extent * region->region[ndim - 1]->extent;
100111
}
101112

102-
// Extent: last two extents product (elements)
103-
PrimExpr extent =
104-
region->region[ndim - 2]->extent * region->region[ndim - 1]->extent;
105-
106113
// ptype and return handle
107114
PrimExpr ptype = tir::TypeAnnotation(buf->dtype);
108115
Array<PrimExpr> acc_args{ptype, buf->data, offset, extent,

testing/python/language/test_tilelang_language_atomic_add.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ def test_atomic_addx2():
260260
run_atomic_addx2(32, 64, 8, 16)
261261

262262

263-
@tilelang.jit(debug_root_path="./testing/python/language")
263+
@tilelang.jit()
264264
def atomic_different_memory_orders_program(M, N, block_M, block_N, dtype="float"):
265265

266266
@T.prim_func

tilelang/analysis/ast_printer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from tvm import tir
2-
from tvm.tir import (
3-
PrimFunc,)
2+
from tvm.tir import PrimFunc
43
from tvm.tir.transform import prim_func_pass
54
from tvm.tir.stmt_functor import ir_transform
65

tilelang/language/reduce.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,8 +246,8 @@ def cumsum_fragment(src: tir.Buffer, dst: tir.Buffer, dim: int, reverse: bool) -
246246
tir.call_intrin(
247247
"handle",
248248
tir.op.Op.get("tl.cumsum"),
249-
buffer_to_tile_region(cumsum_smem),
250-
buffer_to_tile_region(cumsum_smem),
249+
buffer_to_tile_region(cumsum_smem, "r"),
250+
buffer_to_tile_region(cumsum_smem, "w"),
251251
dim,
252252
reverse,
253253
)

0 commit comments

Comments
 (0)