@@ -82,27 +82,34 @@ static PrimExpr MakeAccessPtrFromRegion(const BufferRegion ®ion,
8282 int rw_mask) {
8383 Buffer buf = region->buffer ;
8484 int ndim = static_cast <int >(buf->shape .size ());
85- ICHECK (ndim >= 2 ) << " GEMM expects buffers with at least 2 dims" ;
86-
87- // Compute row-major strides
88- std::vector<PrimExpr> strides (ndim);
89- PrimExpr one = make_const (buf->shape [0 ].dtype (), 1 );
90- PrimExpr cur = one;
91- for (int i = ndim - 1 ; i >= 0 ; --i) {
92- strides[i] = cur;
93- cur = cur * buf->shape [i];
94- }
85+ ICHECK (ndim == 1 || ndim == 2 ) << " Cumsum expects buffers with 1 or 2 dims" ;
86+
87+ PrimExpr offset, extent;
88+ if (ndim == 1 ) {
89+ // Simple 1D region: offset and extent come from the single axis.
90+ auto axis = region->region [0 ];
91+ offset = axis->min ;
92+ extent = axis->extent ;
93+ } else {
94+ // Compute row-major strides for ndim >= 2
95+ std::vector<PrimExpr> strides (ndim);
96+ PrimExpr one = make_const (buf->shape [0 ].dtype (), 1 );
97+ PrimExpr cur = one;
98+ for (int i = ndim - 1 ; i >= 0 ; --i) {
99+ strides[i] = cur;
100+ cur = cur * buf->shape [i];
101+ }
102+ // Offset: sum_{i in [0..ndim-3]} min_i * stride_i
103+ offset = make_const (buf->shape [0 ].dtype (), 0 );
104+ for (int i = 0 ; i < ndim - 2 ; ++i) {
105+ offset = offset + region->region [i]->min * strides[i];
106+ }
95107
96- // Offset: sum_{i in [0..ndim-3]} min_i * stride_i
97- PrimExpr offset = make_const (buf->shape [0 ].dtype (), 0 );
98- for (int i = 0 ; i < ndim - 2 ; ++i) {
99- offset = offset + region->region [i]->min * strides[i];
108+ // Extent: last two extents product (elements)
109+ extent =
110+ region->region [ndim - 2 ]->extent * region->region [ndim - 1 ]->extent ;
100111 }
101112
102- // Extent: last two extents product (elements)
103- PrimExpr extent =
104- region->region [ndim - 2 ]->extent * region->region [ndim - 1 ]->extent ;
105-
106113 // ptype and return handle
107114 PrimExpr ptype = tir::TypeAnnotation (buf->dtype );
108115 Array<PrimExpr> acc_args{ptype, buf->data , offset, extent,
0 commit comments