1+ from typing import Tuple
2+
3+ from .mma_layout import (
4+ mma_load_a_32x8_to_shared_16x16_layout ,
5+ mma_load_b_32x4_to_shared_16x8_layout_16bit ,
6+
7+ )
8+
def mma_sp_load_a_32x8_to_shared_16x32_layout(thread_id, local_id):
    """
    Map a (thread_id, local_id) pair to a position for the sparse-MMA A operand.

    This is a pure pass-through to mma_load_a_32x8_to_shared_16x16_layout: the
    arguments are forwarded unchanged and its result is returned as-is.

    NOTE(review): the name advertises a 16x32 shape while the delegate is named
    16x16 -- presumably the A operand is stored in compressed form; confirm
    against the callers.

    Args:
        thread_id: The thread id, forwarded unchanged to the delegate.
        local_id: The local element id, forwarded unchanged to the delegate.
    Returns:
        Whatever mma_load_a_32x8_to_shared_16x16_layout returns for the same
        arguments.
    """
    position = mma_load_a_32x8_to_shared_16x16_layout(thread_id, local_id)
    return position
11+
def mma_sp_load_b_32x8_to_shared_32x8_layout(thread_id, local_id):
    """
    Map a (thread_id, local_id) pair to a position for the sparse-MMA B operand.

    This is a pure pass-through to mma_load_b_32x4_to_shared_16x8_layout_16bit:
    the arguments are forwarded unchanged and its result is returned as-is.

    NOTE(review): the name advertises 8 local elements per thread while the
    delegate is named for 4 -- presumably the delegate's formula extends to the
    larger local_id range; confirm against its definition.

    Args:
        thread_id: The thread id, forwarded unchanged to the delegate.
        local_id: The local element id, forwarded unchanged to the delegate.
    Returns:
        Whatever mma_load_b_32x4_to_shared_16x8_layout_16bit returns for the
        same arguments.
    """
    position = mma_load_b_32x4_to_shared_16x8_layout_16bit(thread_id, local_id)
    return position
14+
def mma_sp_load_b_32x16_to_shared_32x16_layout(thread_id, local_id):
    """
    Map a (thread_id, local_id) pair to a (row, col) position for the B operand.

    local_id is split into a block index (local_id // 8) and an index inside
    the block (local_id % 8). The inner index is resolved through
    mma_load_b_32x4_to_shared_16x8_layout_16bit, and every block shifts the
    resulting column by a further 8.

    Args:
        thread_id: The thread id, forwarded unchanged to the delegate.
        local_id: The local element id; presumably in [0, 15] given the
            function name -- confirm against the callers.
    Returns:
        row: The row reported by the delegate for the inner index.
        col: The delegate's column plus 8 * (local_id // 8).
    """
    block_index = local_id // 8
    inner_id = local_id % 8
    row, base_col = mma_load_b_32x4_to_shared_16x8_layout_16bit(thread_id, inner_id)
    return row, base_col + 8 * block_index
18+
19+
def get_logical_id(thread_id: int) -> int:
    """
    Collapse a thread id onto a logical id, two logical ids per group of four.

    Thread ids are taken in groups of four consecutive values. Inside a group
    only the parity of the position matters, so positions 0 and 2 share one
    logical id and positions 1 and 3 share the next one
    (thread ids 0..7 map to 0, 1, 0, 1, 2, 3, 2, 3).

    Args:
        thread_id (int): The thread id to collapse.
    Returns:
        int: (thread_id // 4) * 2 plus the parity of (thread_id % 4).
    """
    group = thread_id // 4
    position_in_group = thread_id % 4
    return group * 2 + position_in_group % 2
22+
def metadata_load_32x4_to_shared_16x4_layout_8bit(thread_id: int, local_id: int) -> Tuple[int, int]:
    """
    Map a (thread_id, local_id) pair to a (row, col) position in a 16x4 tile.

    The thread id is first collapsed with get_logical_id. The logical id is
    split by 2 into a base row and a column pair; local_id is split by 2 into
    a row step (each step adds 8 rows) and a column inside the pair.

    NOTE(review): the original note on this function reads "For 16 bit mma
    dtype, 8 bit mma dtype" and states that of every 4 consecutive threads
    only 2 (lower or higher, depending on the selector) are needed to load
    metadata, hence 32x4 // 2 == 16x4. Presumably this means 16-bit MMA
    operands with 8-bit metadata elements -- confirm.

    Args:
        thread_id (int): The thread id in the warp, range [0, 31]
        local_id (int): The local id within the thread, range [0, 3] (u8 * 4)
    Returns:
        row (int): The row index in the shared memory, in [0, 15] for the
            argument ranges above
        col (int): The column index in the shared memory, in [0, 3] for the
            argument ranges above
    """
    logical_id = get_logical_id(thread_id)
    base_row, col_pair = logical_id // 2, logical_id % 2
    row_step, col_in_pair = local_id // 2, local_id % 2
    return base_row + row_step * 8, col_pair * 2 + col_in_pair
41+
42+
def metadata_load_32x2_to_shared_16x2_layout_16bit(thread_id: int, local_id: int) -> Tuple[int, int]:
    """
    Map a (thread_id, local_id) pair to a (row, col) position in a 16x2 tile.

    The thread id is first collapsed with get_logical_id. The logical id is
    split by 2 into a base row and a column; each increment of local_id adds
    8 rows and leaves the column untouched.

    NOTE(review): the original note on this function reads "For 16 bit mma
    dtype, 16 bit mma dtype" and states that of every 4 consecutive threads
    only 2 (lower or higher, depending on the selector) are needed to load
    metadata, hence 32x2 // 2 == 16x2. Presumably this means 16-bit MMA
    operands with 16-bit metadata elements -- confirm.

    Args:
        thread_id (int): The thread id in the warp, range [0, 31]
        local_id (int): The local id within the thread, range [0, 1] (u16 * 2)
    Returns:
        row (int): The row index in the shared memory, in [0, 15] for the
            argument ranges above
        col (int): The column index in the shared memory, in [0, 1] for the
            argument ranges above
    """
    logical_id = get_logical_id(thread_id)
    base_row, col = logical_id // 2, logical_id % 2
    return base_row + local_id * 8, col
59+
if __name__ == "__main__":
    # Ad-hoc debug dump of a thread -> tile mapping.
    #
    # The metadata layouts defined in this file can be inspected the same way:
    # iterate thread ids 0..31 (optionally printing get_logical_id for each)
    # and local ids 0..3 for metadata_load_32x4_to_shared_16x4_layout_8bit,
    # or local ids 0..1 for metadata_load_32x2_to_shared_16x2_layout_16bit.
    #
    # NOTE: the first helper below reuses the name imported from .mma_layout
    # at the top of the file, so it rebinds that name for the rest of the
    # script run.

    def mma_load_b_32x4_to_shared_16x8_layout_16bit(thread_id, local_id):
        """
        Layout formula quoted by the original author:

            groupID           = %laneid >> 2
            threadID_in_group = %laneid % 4

            row = (threadID_in_group * 2) + (i & 0x1)      for bi where i <  2
                  (threadID_in_group * 2) + (i & 0x1) + 8  for bi where i >= 2

            col = groupID
        """
        group_id = thread_id // 4
        id_in_group = thread_id % 4
        row = id_in_group * 2 + (local_id % 2) + (local_id // 2) * 8
        return row, group_id

    def mma_load_b_32x8_to_shared_16x16_layout_16bit_replicate_b(thread_id, local_id):
        """
        Extend the helper above to 8 local ids: local_id % 4 is resolved by
        the helper, and local ids 4..7 land 8 columns further right.
        """
        block_index = local_id // 4
        row, base_col = mma_load_b_32x4_to_shared_16x8_layout_16bit(thread_id, local_id % 4)
        return row, base_col + 8 * block_index

    # Print the position of every (thread, local element) pair.
    for tid in range(32):
        for lid in range(8):
            row, col = mma_load_b_32x8_to_shared_16x16_layout_16bit_replicate_b(tid, lid)
            print(f"thread_id: { tid } , local_id: { lid } => row: { row } , col: { col } ")
0 commit comments