Skip to content

Commit 5cb5c06

Browse files
authored
[Bugfix] Fix missing host cuTensorMapEncodeIm2col call (#1094)
1 parent bddb125 commit 5cb5c06

File tree

3 files changed

+113
-40
lines changed

3 files changed

+113
-40
lines changed

examples/convolution/example_convolution.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ def main(argv=None):
122122
out_c = kernel(a, b)
123123
ref_c = ref_program(S, P, D)(a, b)
124124
torch.testing.assert_close(out_c, ref_c, rtol=1e-2, atol=1e-2)
125+
print("All checks passed.✅")
125126

126127

127128
if __name__ == "__main__":

src/transform/inject_tma_barrier.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ class TmaExpectTxRewriter : public IRMutatorWithAnalyzer {
163163
}
164164

165165
PrimExpr VisitExpr_(const CallNode *op) {
166-
if (op->op.same_as(tma_load())) {
166+
if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) {
167167
auto arg0 = op->args[0].as<Call>();
168168
bool is_1d_tma_load =
169169
arg0 && !arg0.value()->op.same_as(create_tma_descriptor()) &&
@@ -203,7 +203,7 @@ class TmaBarrierCollector : public IRVisitorWithAnalyzer {
203203

204204
void VisitStmt_(const EvaluateNode *op) final {
205205
if (const auto *call = op->value.as<CallNode>()) {
206-
if (call->op.same_as(tma_load())) {
206+
if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
207207
pending_tma_ops_.push_back(GetRef<Call>(call));
208208
} else if (call->op.same_as(mbarrier_expect_tx())) {
209209
pending_tma_ops_.push_back(GetRef<Call>(call));
@@ -451,15 +451,16 @@ class TmaBarrierRewriter : public IRMutatorWithAnalyzer {
451451
}
452452

453453
PrimExpr VisitExpr_(const CallNode *op) {
454-
if (op->op.same_as(tma_load())) {
454+
if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) {
455455
// this call must already be recorded in tma_op_to_barrier_id_
456456
ICHECK(tma_op_to_barrier_id_.count(GetRef<Call>(op)))
457457
<< "tma_load must be in the tma_op_to_barrier_id_";
458458
auto barrier_id = tma_op_to_barrier_id_[GetRef<Call>(op)];
459459
auto new_args = op->args;
460460
auto arg0 = op->args[0].as<Call>();
461461
auto is_1d_tma_load =
462-
arg0 && !arg0.value()->op.same_as(create_tma_descriptor());
462+
arg0 && !arg0.value()->op.same_as(create_tma_descriptor()) &&
463+
!arg0.value()->op.same_as(create_tma_im2col_descriptor());
463464
if (is_1d_tma_load) {
464465
new_args.Set(2, barrier_id);
465466
} else {

tilelang/jit/adapter/wrapper.py

Lines changed: 107 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,35 @@ def call({}):
106106
\t}}
107107
"""
108108

109+
TMA_IM2COL_DESC_INIT_FUNC = """
110+
\tCUtensorMap {0};
111+
\tCUtensorMapDataType {0}_type= (CUtensorMapDataType){1};
112+
\tcuuint32_t {0}_tensorRank= {2};
113+
\tvoid *{0}_globalAddress= {3};
114+
\tcuuint64_t {0}_globalDim[{2}]= {{{4}}};
115+
\tcuuint64_t {0}_globalStride[{2}]= {{{5}}};
116+
\tcuuint32_t {0}_elementStrides[{2}]= {{{6}}};
117+
\tint {0}_lowerCorner[{2} - 2]= {{{7}}};
118+
\tint {0}_upperCorner[{2} - 2]= {{{8}}};
119+
\tcuuint32_t {0}_channelsPerPixel= {9};
120+
\tcuuint32_t {0}_pixelsPerColumn= {10};
121+
\tCUtensorMapInterleave {0}_interleave= (CUtensorMapInterleave){11};
122+
\tCUtensorMapSwizzle {0}_swizzle= (CUtensorMapSwizzle){12};
123+
\tCUtensorMapL2promotion {0}_l2Promotion= (CUtensorMapL2promotion){13};
124+
\tCUtensorMapFloatOOBfill {0}_oobFill= (CUtensorMapFloatOOBfill){14};
125+
126+
\tCUresult {0}_result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeIm2col)(
127+
&{0}, {0}_type, {0}_tensorRank, {0}_globalAddress, {0}_globalDim, {0}_globalStride + 1,
128+
{0}_lowerCorner, {0}_upperCorner, {0}_channelsPerPixel, {0}_pixelsPerColumn, {0}_elementStrides, {0}_interleave, {0}_swizzle, {0}_l2Promotion, {0}_oobFill);
129+
130+
\tif ({0}_result != CUDA_SUCCESS) {{
131+
\t\tstd::stringstream ss;
132+
\t\tss << "Error: Failed to initialize the TMA descriptor {0}";
133+
\t\tsnprintf(error_buf, ERROR_BUF_SIZE, "%s", ss.str().c_str());
134+
\t\treturn -1;
135+
\t}}
136+
"""
137+
109138
TMA_DESC_INIT_FUNC_PY = """
110139
\t{0}_type = cuda.bindings.driver.CUtensorMapDataType({1})
111140
\t{0}_tensorRank = {2}
@@ -401,50 +430,92 @@ def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str],
401430
if len(args) < 3:
402431
raise ValueError(
403432
f"TMA descriptor args too short: {len(args)} elements, expected at least 3")
404-
_, dtype, tensor_rank, globalAddress, *remaining_args = args[1:]
433+
434+
tma_create_str, _, dtype, tensor_rank, globalAddress, *remaining_args = args
435+
436+
is_img2col = (tma_create_str.value == "__tvm_tensormap_create_im2col")
405437
dtype = self._pythonic_expr(dtype)
406438
tensor_rank = int(self._pythonic_expr(tensor_rank))
407439

408440
# Validate tensor_rank
409441
if not isinstance(tensor_rank, int) or tensor_rank <= 0:
410442
raise ValueError(f"Invalid tensor_rank: {tensor_rank}. Must be a positive integer")
411443

412-
# Calculate required length for remaining_args
413-
expected_args_len = 4 * tensor_rank + 4 # 4 groups of tensor_rank size + 4 parameters
414-
if len(remaining_args) < expected_args_len:
415-
raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, "
416-
f"expected {expected_args_len} for tensor_rank {tensor_rank}")
417-
418-
# Extract dimensions and strides using list slicing
419-
global_dim = remaining_args[:tensor_rank]
420-
global_stride = remaining_args[tensor_rank:2 * tensor_rank]
421-
box_dim = remaining_args[2 * tensor_rank:3 * tensor_rank]
422-
element_strides = remaining_args[3 * tensor_rank:4 * tensor_rank]
423-
424-
global_dim = [self._pythonic_expr(i) for i in global_dim]
425-
global_stride = [self._pythonic_expr(i) for i in global_stride]
426-
box_dim = [self._pythonic_expr(i) for i in box_dim]
427-
element_strides = [self._pythonic_expr(i) for i in element_strides]
428-
429-
# Extract remaining parameters
430-
try:
431-
interleave, swizzle, l2Promotion, oobFill = remaining_args[4 * tensor_rank:4 *
432-
tensor_rank + 4]
433-
interleave = self._pythonic_expr(interleave)
434-
swizzle = self._pythonic_expr(swizzle)
435-
l2Promotion = self._pythonic_expr(l2Promotion)
436-
oobFill = self._pythonic_expr(oobFill)
437-
except ValueError as e:
438-
raise ValueError(
439-
"Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)"
440-
) from e
444+
if not is_img2col:
445+
# Calculate required length for remaining_args
446+
expected_args_len = 4 * tensor_rank + 4 # 4 groups of tensor_rank size + 4 parameters
447+
if len(remaining_args) < expected_args_len:
448+
raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, "
449+
f"expected {expected_args_len} for tensor_rank {tensor_rank}")
450+
451+
# Extract dimensions and strides using list slicing
452+
global_dim = remaining_args[:tensor_rank]
453+
global_stride = remaining_args[tensor_rank:2 * tensor_rank]
454+
box_dim = remaining_args[2 * tensor_rank:3 * tensor_rank]
455+
element_strides = remaining_args[3 * tensor_rank:4 * tensor_rank]
456+
457+
global_dim = [self._pythonic_expr(i) for i in global_dim]
458+
global_stride = [self._pythonic_expr(i) for i in global_stride]
459+
box_dim = [self._pythonic_expr(i) for i in box_dim]
460+
element_strides = [self._pythonic_expr(i) for i in element_strides]
461+
462+
# Extract remaining parameters
463+
try:
464+
interleave, swizzle, l2Promotion, oobFill = remaining_args[4 * tensor_rank:4 *
465+
tensor_rank + 4]
466+
interleave = self._pythonic_expr(interleave)
467+
swizzle = self._pythonic_expr(swizzle)
468+
l2Promotion = self._pythonic_expr(l2Promotion)
469+
oobFill = self._pythonic_expr(oobFill)
470+
except ValueError as e:
471+
raise ValueError(
472+
"Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)"
473+
) from e
474+
475+
tma_descripter_init += TMA_DESC_INIT_FUNC.format(
476+
handle_name, dtype, tensor_rank, globalAddress, ",".join(global_dim),
477+
",".join(global_stride), ",".join(box_dim), ",".join(element_strides),
478+
interleave, swizzle, l2Promotion, oobFill)
479+
else:
480+
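# Required length of remaining_args: 3 rank-sized groups (global_dim, global_stride, element_strides) + 2 corner groups of (rank - 2) + 6 scalar parameters = 5 * rank + 2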
481+
expected_args_len = 5 * tensor_rank + 2
482+
if len(remaining_args) < expected_args_len:
483+
raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, "
484+
f"expected {expected_args_len} for tensor_rank {tensor_rank}")
485+
486+
# Extract dimensions and strides using list slicing
487+
global_dim = remaining_args[:tensor_rank]
488+
global_stride = remaining_args[tensor_rank:2 * tensor_rank]
489+
element_strides = remaining_args[2 * tensor_rank:3 * tensor_rank]
490+
lower_corner = remaining_args[3 * tensor_rank:4 * tensor_rank - 2]
491+
upper_corner = remaining_args[4 * tensor_rank - 2:5 * tensor_rank - 4]
492+
global_dim = [self._pythonic_expr(i) for i in global_dim]
493+
global_stride = [self._pythonic_expr(i) for i in global_stride]
494+
element_strides = [self._pythonic_expr(i) for i in element_strides]
495+
lower_corner = [self._pythonic_expr(i) for i in lower_corner]
496+
upper_corner = [self._pythonic_expr(i) for i in upper_corner]
497+
498+
# Extract remaining parameters
499+
try:
500+
smem_box_pixel, smem_box_channel, interleave, swizzle, l2Promotion, oobFill = remaining_args[
501+
5 * tensor_rank - 4:5 * tensor_rank + 2]
502+
smem_box_pixel = self._pythonic_expr(smem_box_pixel)
503+
smem_box_channel = self._pythonic_expr(smem_box_channel)
504+
interleave = self._pythonic_expr(interleave)
505+
swizzle = self._pythonic_expr(swizzle)
506+
l2Promotion = self._pythonic_expr(l2Promotion)
507+
oobFill = self._pythonic_expr(oobFill)
508+
except ValueError as e:
509+
raise ValueError(
510+
"Failed to unpack the final 6 TMA parameters (smem_box_pixel, smem_box_channel, interleave, swizzle, l2Promotion, oobFill)"
511+
) from e
512+
513+
tma_descripter_init += TMA_IM2COL_DESC_INIT_FUNC.format(
514+
handle_name, dtype, tensor_rank, globalAddress, ",".join(global_dim),
515+
",".join(global_stride), ",".join(element_strides), ",".join(lower_corner),
516+
",".join(upper_corner), smem_box_channel, smem_box_pixel, interleave, swizzle,
517+
l2Promotion, oobFill)
441518

442-
tma_descripter_init += TMA_DESC_INIT_FUNC.format(handle_name, dtype, tensor_rank,
443-
globalAddress, ",".join(global_dim),
444-
",".join(global_stride),
445-
",".join(box_dim),
446-
",".join(element_strides), interleave,
447-
swizzle, l2Promotion, oobFill)
448519
return tma_descripter_init
449520

450521
def parse_source_information(self):

0 commit comments

Comments
 (0)