@@ -106,6 +106,35 @@ def call({}):
106106\t }}
107107"""
108108
# Source-text template (filled in with str.format) that emits the host-side
# initialisation of one im2col-mode TMA descriptor. The emitted code declares a
# CUtensorMap plus one local per encoder argument, calls
# cuTensorMapEncodeIm2col through CUTLASS_CUDA_DRIVER_WRAPPER_CALL, and on any
# result other than CUDA_SUCCESS writes a message into error_buf and returns -1.
#
# Positional placeholders, in the order the im2col branch of
# generate_tma_descriptor_args passes them to .format():
#   {0}  descriptor handle name; also the prefix of every emitted local
#   {1}  data type value, cast to CUtensorMapDataType
#   {2}  tensor rank; also used as the length of the emitted arrays
#   {3}  global address expression
#   {4}  comma-joined global dims
#   {5}  comma-joined global strides
#   {6}  comma-joined element strides
#   {7}  comma-joined lower-corner values
#   {8}  comma-joined upper-corner values
#   {9}  channelsPerPixel   (the caller passes smem_box_channel)
#   {10} pixelsPerColumn    (the caller passes smem_box_pixel)
#   {11} interleave, {12} swizzle, {13} l2Promotion, {14} oobFill
#        (each cast to the corresponding CUtensorMap* enum type)
#
# Doubled braces ({{ and }}) are str.format escapes that produce literal C
# braces. The call passes `{0}_globalStride + 1`, i.e. the stride array
# starting from its second element. The two corner arrays are declared with
# `{2} - 2` elements.
# NOTE(review): a rank below 3 would therefore give a non-positive array length
# in the emitted code -- confirm that callers never request im2col with rank < 3.
# NOTE(review): error_buf and ERROR_BUF_SIZE are not defined by this template;
# presumably the surrounding generated function provides them -- confirm.
109+ TMA_IM2COL_DESC_INIT_FUNC = """
110+ \t CUtensorMap {0};
111+ \t CUtensorMapDataType {0}_type= (CUtensorMapDataType){1};
112+ \t cuuint32_t {0}_tensorRank= {2};
113+ \t void *{0}_globalAddress= {3};
114+ \t cuuint64_t {0}_globalDim[{2}]= {{{4}}};
115+ \t cuuint64_t {0}_globalStride[{2}]= {{{5}}};
116+ \t cuuint32_t {0}_elementStrides[{2}]= {{{6}}};
117+ \t int {0}_lowerCorner[{2} - 2]= {{{7}}};
118+ \t int {0}_upperCorner[{2} - 2]= {{{8}}};
119+ \t cuuint32_t {0}_channelsPerPixel= {9};
120+ \t cuuint32_t {0}_pixelsPerColumn= {10};
121+ \t CUtensorMapInterleave {0}_interleave= (CUtensorMapInterleave){11};
122+ \t CUtensorMapSwizzle {0}_swizzle= (CUtensorMapSwizzle){12};
123+ \t CUtensorMapL2promotion {0}_l2Promotion= (CUtensorMapL2promotion){13};
124+ \t CUtensorMapFloatOOBfill {0}_oobFill= (CUtensorMapFloatOOBfill){14};
125+
126+ \t CUresult {0}_result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeIm2col)(
127+ &{0}, {0}_type, {0}_tensorRank, {0}_globalAddress, {0}_globalDim, {0}_globalStride + 1,
128+ {0}_lowerCorner, {0}_upperCorner, {0}_channelsPerPixel, {0}_pixelsPerColumn, {0}_elementStrides, {0}_interleave, {0}_swizzle, {0}_l2Promotion, {0}_oobFill);
129+
130+ \t if ({0}_result != CUDA_SUCCESS) {{
131+ \t \t std::stringstream ss;
132+ \t \t ss << "Error: Failed to initialize the TMA descriptor {0}";
133+ \t \t snprintf(error_buf, ERROR_BUF_SIZE, "%s", ss.str().c_str());
134+ \t \t return -1;
135+ \t }}
136+ """
137+
109138TMA_DESC_INIT_FUNC_PY = """
110139\t {0}_type = cuda.bindings.driver.CUtensorMapDataType({1})
111140\t {0}_tensorRank = {2}
@@ -401,50 +430,92 @@ def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str],
401430 if len (args ) < 3 :
402431 raise ValueError (
403432 f"TMA descriptor args too short: { len (args )} elements, expected at least 3" )
404- _ , dtype , tensor_rank , globalAddress , * remaining_args = args [1 :]
433+
434+ tma_create_str , _ , dtype , tensor_rank , globalAddress , * remaining_args = args
435+
436+ is_img2col = (tma_create_str .value == "__tvm_tensormap_create_im2col" )
405437 dtype = self ._pythonic_expr (dtype )
406438 tensor_rank = int (self ._pythonic_expr (tensor_rank ))
407439
408440 # Validate tensor_rank
409441 if not isinstance (tensor_rank , int ) or tensor_rank <= 0 :
410442 raise ValueError (f"Invalid tensor_rank: { tensor_rank } . Must be a positive integer" )
411443
412- # Calculate required length for remaining_args
413- expected_args_len = 4 * tensor_rank + 4 # 4 groups of tensor_rank size + 4 parameters
414- if len (remaining_args ) < expected_args_len :
415- raise ValueError (f"Insufficient remaining args: got { len (remaining_args )} , "
416- f"expected { expected_args_len } for tensor_rank { tensor_rank } " )
417-
418- # Extract dimensions and strides using list slicing
419- global_dim = remaining_args [:tensor_rank ]
420- global_stride = remaining_args [tensor_rank :2 * tensor_rank ]
421- box_dim = remaining_args [2 * tensor_rank :3 * tensor_rank ]
422- element_strides = remaining_args [3 * tensor_rank :4 * tensor_rank ]
423-
424- global_dim = [self ._pythonic_expr (i ) for i in global_dim ]
425- global_stride = [self ._pythonic_expr (i ) for i in global_stride ]
426- box_dim = [self ._pythonic_expr (i ) for i in box_dim ]
427- element_strides = [self ._pythonic_expr (i ) for i in element_strides ]
428-
429- # Extract remaining parameters
430- try :
431- interleave , swizzle , l2Promotion , oobFill = remaining_args [4 * tensor_rank :4 *
432- tensor_rank + 4 ]
433- interleave = self ._pythonic_expr (interleave )
434- swizzle = self ._pythonic_expr (swizzle )
435- l2Promotion = self ._pythonic_expr (l2Promotion )
436- oobFill = self ._pythonic_expr (oobFill )
437- except ValueError as e :
438- raise ValueError (
439- "Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)"
440- ) from e
444+ if not is_img2col :
445+ # Calculate required length for remaining_args
446+ expected_args_len = 4 * tensor_rank + 4 # 4 groups of tensor_rank size + 4 parameters
447+ if len (remaining_args ) < expected_args_len :
448+ raise ValueError (f"Insufficient remaining args: got { len (remaining_args )} , "
449+ f"expected { expected_args_len } for tensor_rank { tensor_rank } " )
450+
451+ # Extract dimensions and strides using list slicing
452+ global_dim = remaining_args [:tensor_rank ]
453+ global_stride = remaining_args [tensor_rank :2 * tensor_rank ]
454+ box_dim = remaining_args [2 * tensor_rank :3 * tensor_rank ]
455+ element_strides = remaining_args [3 * tensor_rank :4 * tensor_rank ]
456+
457+ global_dim = [self ._pythonic_expr (i ) for i in global_dim ]
458+ global_stride = [self ._pythonic_expr (i ) for i in global_stride ]
459+ box_dim = [self ._pythonic_expr (i ) for i in box_dim ]
460+ element_strides = [self ._pythonic_expr (i ) for i in element_strides ]
461+
462+ # Extract remaining parameters
463+ try :
464+ interleave , swizzle , l2Promotion , oobFill = remaining_args [4 * tensor_rank :4 *
465+ tensor_rank + 4 ]
466+ interleave = self ._pythonic_expr (interleave )
467+ swizzle = self ._pythonic_expr (swizzle )
468+ l2Promotion = self ._pythonic_expr (l2Promotion )
469+ oobFill = self ._pythonic_expr (oobFill )
470+ except ValueError as e :
471+ raise ValueError (
472+ "Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)"
473+ ) from e
474+
475+ tma_descripter_init += TMA_DESC_INIT_FUNC .format (
476+ handle_name , dtype , tensor_rank , globalAddress , "," .join (global_dim ),
477+ "," .join (global_stride ), "," .join (box_dim ), "," .join (element_strides ),
478+ interleave , swizzle , l2Promotion , oobFill )
479+ else :
480+ # Calculate required length for remaining_args
481+ expected_args_len = 5 * tensor_rank + 2
482+ if len (remaining_args ) < expected_args_len :
483+ raise ValueError (f"Insufficient remaining args: got { len (remaining_args )} , "
484+ f"expected { expected_args_len } for tensor_rank { tensor_rank } " )
485+
486+ # Extract dimensions and strides using list slicing
487+ global_dim = remaining_args [:tensor_rank ]
488+ global_stride = remaining_args [tensor_rank :2 * tensor_rank ]
489+ element_strides = remaining_args [2 * tensor_rank :3 * tensor_rank ]
490+ lower_corner = remaining_args [3 * tensor_rank :4 * tensor_rank - 2 ]
491+ upper_corner = remaining_args [4 * tensor_rank - 2 :5 * tensor_rank - 4 ]
492+ global_dim = [self ._pythonic_expr (i ) for i in global_dim ]
493+ global_stride = [self ._pythonic_expr (i ) for i in global_stride ]
494+ element_strides = [self ._pythonic_expr (i ) for i in element_strides ]
495+ lower_corner = [self ._pythonic_expr (i ) for i in lower_corner ]
496+ upper_corner = [self ._pythonic_expr (i ) for i in upper_corner ]
497+
498+ # Extract remaining parameters
499+ try :
500+ smem_box_pixel , smem_box_channel , interleave , swizzle , l2Promotion , oobFill = remaining_args [
501+ 5 * tensor_rank - 4 :5 * tensor_rank + 2 ]
502+ smem_box_pixel = self ._pythonic_expr (smem_box_pixel )
503+ smem_box_channel = self ._pythonic_expr (smem_box_channel )
504+ interleave = self ._pythonic_expr (interleave )
505+ swizzle = self ._pythonic_expr (swizzle )
506+ l2Promotion = self ._pythonic_expr (l2Promotion )
507+ oobFill = self ._pythonic_expr (oobFill )
508+ except ValueError as e :
509+ raise ValueError (
510+ "Failed to unpack the final 6 TMA parameters (smem_box_pixel, smem_box_channel, interleave, swizzle, l2Promotion, oobFill)"
511+ ) from e
512+
513+ tma_descripter_init += TMA_IM2COL_DESC_INIT_FUNC .format (
514+ handle_name , dtype , tensor_rank , globalAddress , "," .join (global_dim ),
515+ "," .join (global_stride ), "," .join (element_strides ), "," .join (lower_corner ),
516+ "," .join (upper_corner ), smem_box_channel , smem_box_pixel , interleave , swizzle ,
517+ l2Promotion , oobFill )
441518
442- tma_descripter_init += TMA_DESC_INIT_FUNC .format (handle_name , dtype , tensor_rank ,
443- globalAddress , "," .join (global_dim ),
444- "," .join (global_stride ),
445- "," .join (box_dim ),
446- "," .join (element_strides ), interleave ,
447- swizzle , l2Promotion , oobFill )
448519 return tma_descripter_init
449520
450521 def parse_source_information (self ):
0 commit comments