From b98596ef44cb0ef67cadcd32bfd250c03dd18f3b Mon Sep 17 00:00:00 2001 From: cherichy Date: Wed, 7 May 2025 00:36:38 -0700 Subject: [PATCH 1/2] fix get_swizzle_layout implementation. --- tilelang/intrinsics/mma_layout.py | 56 +++++++++++-------------------- 1 file changed, 19 insertions(+), 37 deletions(-) diff --git a/tilelang/intrinsics/mma_layout.py b/tilelang/intrinsics/mma_layout.py index 32faae04a..ac84bb8db 100644 --- a/tilelang/intrinsics/mma_layout.py +++ b/tilelang/intrinsics/mma_layout.py @@ -92,25 +92,17 @@ def shared_32x16_to_mma_32x16_smoothlayout(i, j): return (i * 2 + j // 16, j % 16) -def get_swizzle_layout(row_idx, col_idx, row_size, dtype: Union[DataType, str]): +def get_swizzle_layout(row_idx, col_idx, row_size, dtype: Union[DataType, str], swizzle_bytes = None): ana = arith.Analyzer() - BANK_SIZE_BYTES = 128 if isinstance(dtype, str): dtype = DataType(dtype) - col_idx_outer, col_idx_inner = col_idx // (BANK_SIZE_BYTES // dtype.bits), col_idx % ( - BANK_SIZE_BYTES // dtype.bits) - # use transaction bits to support diverse dtype. - # for fp16, 64 elems * 16 bits = 1024 bits, 32 elems * 32 bits = 512 bits - # for int8, 128 elems * 8 bits = 1024 bits, 64 elems * 8 bits = 512 bits - coalescent_bits = dtype.bits * row_size - # permutation on 4 banks, each bank has 32 bits - bank_elems = BANK_SIZE_BYTES // dtype.bits - new_col_idx_outer = None - - if coalescent_bits % 1024 == 0: + row_bytes = dtype.bits * row_size // 8 + assert row_bytes % 32 == 0, f"Row size must be multiple of 32B." + if swizzle_bytes is None: + swizzle_bytes = min(128, row_bytes) + # 128B swizzle # Use 8 * 8 permuted layout - # Every number below corresponds to 8 consecutive fp16 number in shared mem, i.e. 
one read - # Every row below corresponds to 32 banks + # Every number below corresponds to 16B # 0 1 2 3 4 5 6 7 ==> 0 1 2 3 4 5 6 7 # 0 1 2 3 4 5 6 7 ==> 1 0 3 2 5 4 7 6 # 0 1 2 3 4 5 6 7 ==> 2 3 0 1 6 7 4 5 @@ -119,33 +111,23 @@ def get_swizzle_layout(row_idx, col_idx, row_size, dtype: Union[DataType, str]): # 0 1 2 3 4 5 6 7 ==> 5 4 7 6 1 0 3 2 # 0 1 2 3 4 5 6 7 ==> 6 7 4 5 2 3 0 1 # 0 1 2 3 4 5 6 7 ==> 7 6 5 4 3 2 1 0 - row_idx_sub = row_idx % bank_elems - new_col_idx_outer = col_idx_outer ^ row_idx_sub - else: - assert coalescent_bits % 512 == 0 + # 64B swizzle # Use 8 * 4 permuted layout - # Every number below corresponds to 8 consecutive fp16 number in shared mem, i.e. one read - # Every row below corresponds to 16 banks - # 0 1 2 3 ==> 0 1 2 3 - # 0 1 2 3 ==> 0 1 2 3 - # 0 1 2 3 ==> 1 0 3 2 - # 0 1 2 3 ==> 1 0 3 2 - # 0 1 2 3 ==> 2 3 0 1 - # 0 1 2 3 ==> 2 3 0 1 - # 0 1 2 3 ==> 3 2 1 0 - # 0 1 2 3 ==> 3 2 1 0 - # View with 8 elements per row: + # Every number below corresponds to 16B # 0 1 2 3 4 0 1 2 3 ==> 0 1 2 3 0 1 2 3 # 0 1 2 3 4 0 1 2 3 ==> 1 0 3 2 1 0 3 2 # 0 1 2 3 4 0 1 2 3 ==> 2 3 0 1 2 3 0 1 # 0 1 2 3 4 0 1 2 3 ==> 3 2 1 0 3 2 1 0 - row_idx_sub = row_idx % bank_elems - # Interleave elems per byte - interleave_elems = 32 // dtype.bits - new_col_idx_outer = col_idx_outer ^ (row_idx_sub // interleave_elems) - - assert (new_col_idx_outer is not None), f"Unsupported dtype {dtype} with {coalescent_bits} bits" - return row_idx, ana.simplify(new_col_idx_outer * bank_elems + col_idx_inner) + # 32B swizzle + # Use 8 * 2 permuted layout + # Every number below corresponds to 16B + # 0 1 2 3 4 5 6 7 ==> 0 1 2 3 4 5 6 7 + # 0 1 2 3 4 5 6 7 ==> 1 0 3 2 5 4 7 6 + elem_per_16B = 128 // dtype.bits + col_idx_16B = col_idx // elem_per_16B + col_idx_in_16B = col_idx % elem_per_16B + new_col_idx_16B = col_idx_16B ^ (row_idx % (swizzle_bytes // 16)) + return row_idx, ana.simplify(new_col_idx_16B * elem_per_16B + col_idx_in_16B) def make_mma_swizzle_layout(shared_buf, 
is_smooth: bool = False): From a37aaa8e8b0395b1deeb654f0cb531271ed2b281 Mon Sep 17 00:00:00 2001 From: cherichy Date: Wed, 7 May 2025 00:44:51 -0700 Subject: [PATCH 2/2] format. --- tilelang/intrinsics/mma_layout.py | 44 +++++++++++++++---------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tilelang/intrinsics/mma_layout.py b/tilelang/intrinsics/mma_layout.py index ac84bb8db..ba8f5cbe6 100644 --- a/tilelang/intrinsics/mma_layout.py +++ b/tilelang/intrinsics/mma_layout.py @@ -92,37 +92,37 @@ def shared_32x16_to_mma_32x16_smoothlayout(i, j): return (i * 2 + j // 16, j % 16) -def get_swizzle_layout(row_idx, col_idx, row_size, dtype: Union[DataType, str], swizzle_bytes = None): +def get_swizzle_layout(row_idx, col_idx, row_size, dtype: Union[DataType, str], swizzle_bytes=None): ana = arith.Analyzer() if isinstance(dtype, str): dtype = DataType(dtype) row_bytes = dtype.bits * row_size // 8 - assert row_bytes % 32 == 0, f"Row size must be multiple of 32B." + assert row_bytes % 32 == 0, "Row size must be multiple of 32B." 
if swizzle_bytes is None: swizzle_bytes = min(128, row_bytes) # 128B swizzle - # Use 8 * 8 permuted layout - # Every number below corresponds to 16B - # 0 1 2 3 4 5 6 7 ==> 0 1 2 3 4 5 6 7 - # 0 1 2 3 4 5 6 7 ==> 1 0 3 2 5 4 7 6 - # 0 1 2 3 4 5 6 7 ==> 2 3 0 1 6 7 4 5 - # 0 1 2 3 4 5 6 7 ==> 3 2 1 0 7 6 5 4 - # 0 1 2 3 4 5 6 7 ==> 4 5 6 7 0 1 2 3 - # 0 1 2 3 4 5 6 7 ==> 5 4 7 6 1 0 3 2 - # 0 1 2 3 4 5 6 7 ==> 6 7 4 5 2 3 0 1 - # 0 1 2 3 4 5 6 7 ==> 7 6 5 4 3 2 1 0 + # Use 8 * 8 permuted layout + # Every number below corresponds to 16B + # 0 1 2 3 4 5 6 7 ==> 0 1 2 3 4 5 6 7 + # 0 1 2 3 4 5 6 7 ==> 1 0 3 2 5 4 7 6 + # 0 1 2 3 4 5 6 7 ==> 2 3 0 1 6 7 4 5 + # 0 1 2 3 4 5 6 7 ==> 3 2 1 0 7 6 5 4 + # 0 1 2 3 4 5 6 7 ==> 4 5 6 7 0 1 2 3 + # 0 1 2 3 4 5 6 7 ==> 5 4 7 6 1 0 3 2 + # 0 1 2 3 4 5 6 7 ==> 6 7 4 5 2 3 0 1 + # 0 1 2 3 4 5 6 7 ==> 7 6 5 4 3 2 1 0 # 64B swizzle - # Use 8 * 4 permuted layout - # Every number below corresponds to 16B - # 0 1 2 3 4 0 1 2 3 ==> 0 1 2 3 0 1 2 3 - # 0 1 2 3 4 0 1 2 3 ==> 1 0 3 2 1 0 3 2 - # 0 1 2 3 4 0 1 2 3 ==> 2 3 0 1 2 3 0 1 - # 0 1 2 3 4 0 1 2 3 ==> 3 2 1 0 3 2 1 0 + # Use 8 * 4 permuted layout + # Every number below corresponds to 16B + # 0 1 2 3 0 1 2 3 ==> 0 1 2 3 0 1 2 3 + # 0 1 2 3 0 1 2 3 ==> 1 0 3 2 1 0 3 2 + # 0 1 2 3 0 1 2 3 ==> 2 3 0 1 2 3 0 1 + # 0 1 2 3 0 1 2 3 ==> 3 2 1 0 3 2 1 0 # 32B swizzle - # Use 8 * 2 permuted layout - # Every number below corresponds to 16B - # 0 1 2 3 4 5 6 7 ==> 0 1 2 3 4 5 6 7 - # 0 1 2 3 4 5 6 7 ==> 1 0 3 2 5 4 7 6 + # Use 8 * 2 permuted layout + # Every number below corresponds to 16B + # 0 1 2 3 4 5 6 7 ==> 0 1 2 3 4 5 6 7 + # 0 1 2 3 4 5 6 7 ==> 1 0 3 2 5 4 7 6 elem_per_16B = 128 // dtype.bits col_idx_16B = col_idx // elem_per_16B col_idx_in_16B = col_idx % elem_per_16B