Skip to content

Commit 540a803

Browse files
committed
[gemm_sp] enable fp16/bf16 on sm8x
1 parent aab713f commit 540a803

File tree

8 files changed

+938
-49
lines changed

8 files changed

+938
-49
lines changed

tilelang/intrinsics/mma_layout.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def mma_load_a_32x8_to_shared_16x16_layout(thread_id, local_id):
156156
(threadID_in_group * 2) + (i & 0x1) + 8 for ai where i >= 4
157157
"""
158158
row = (thread_id // 4) + 8 * (local_id % 4 // 2)
159-
col = (thread_id % 4) * 2 +(local_id % 2) + 8 * (local_id // 4)
159+
col = (thread_id % 4) * 2 + (local_id % 2) + 8 * (local_id // 4)
160160
return row, col
161161

162162
def mma_load_b_32x16_to_shared_16x32_layout(thread_id, local_id):
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
from typing import Tuple
2+
3+
from .mma_layout import (
4+
mma_load_a_32x8_to_shared_16x16_layout,
5+
mma_load_b_32x4_to_shared_16x8_layout_16bit,
6+
7+
)
8+
9+
def mma_sp_load_a_32x8_to_shared_16x32_layout(thread_id, local_id):
    """Map a (thread, local element) pair of the sparse A operand to shared memory.

    This is a straight delegate: the coordinates are whatever
    ``mma_load_a_32x8_to_shared_16x16_layout`` produces for the same arguments.

    NOTE(review): the name advertises a 16x32 tile while the delegate is the
    16x16 layout -- presumably the difference is the sparse compression of A;
    confirm against the caller.

    Args:
        thread_id: The thread id inside the warp.
        local_id: The index of the element held by that thread.

    Returns:
        The result of the dense 16x16 A layout, returned unchanged.
    """
    coords = mma_load_a_32x8_to_shared_16x16_layout(thread_id, local_id)
    return coords
11+
12+
def mma_sp_load_b_32x8_to_shared_32x8_layout(thread_id, local_id):
    """Map a (thread, local element) pair of the sparse-MMA B operand to shared memory.

    This is a straight delegate to
    ``mma_load_b_32x4_to_shared_16x8_layout_16bit``; no index is adjusted here.

    Args:
        thread_id: The thread id inside the warp.
        local_id: The index of the element held by that thread.

    Returns:
        The result of the 16-bit B layout helper, returned unchanged.
    """
    coords = mma_load_b_32x4_to_shared_16x8_layout_16bit(thread_id, local_id)
    return coords
14+
15+
def mma_sp_load_b_32x16_to_shared_32x16_layout(thread_id, local_id):
    """Map a (thread, local element) pair of the B operand onto a two-block-wide tile.

    ``local_id`` is split into a block selector (``local_id // 8``) and an
    index inside the block (``local_id % 8``). The in-block index is resolved
    by ``mma_load_b_32x4_to_shared_16x8_layout_16bit`` and the selector then
    shifts the column by 8 per block.

    Args:
        thread_id: The thread id inside the warp.
        local_id: The index of the element held by that thread.

    Returns:
        ``(row, col)`` where ``row`` comes from the helper unchanged and
        ``col`` is the helper's column plus ``8 * (local_id // 8)``.
    """
    block = local_id // 8
    base_row, base_col = mma_load_b_32x4_to_shared_16x8_layout_16bit(thread_id, local_id % 8)
    return base_row, base_col + 8 * block
18+
19+
20+
def get_logical_id(thread_id: int) -> int:
    """Fold every group of four consecutive thread ids onto two logical ids.

    Threads ``4g .. 4g+3`` map to ``2g, 2g+1, 2g, 2g+1``: the group index
    picks the pair and the parity of the position inside the group picks the
    member of that pair.

    Args:
        thread_id: The thread id to fold.

    Returns:
        ``(thread_id // 4) * 2 + (thread_id % 4) % 2``.
    """
    group = thread_id // 4
    parity_in_group = (thread_id % 4) % 2
    return group * 2 + parity_in_group
22+
23+
def metadata_load_32x4_to_shared_16x4_layout_8bit(
    thread_id: int,
    local_id: int,
) -> Tuple[int, int]:
    """Map a (thread, local element) pair to its metadata slot (4 elements per thread).

    The thread id is first folded with ``get_logical_id``, so within each
    group of four consecutive threads only two distinct slots are addressed
    (the original note says which two threads actually load depends on the
    selector). The folded id then supplies the base row and a column pair,
    while ``local_id`` picks the row half (+8) and the column inside the pair.

    Args:
        thread_id (int): The thread id in the warp, range [0, 31].
        local_id (int): The local id within the thread, range [0, 3] (u8 * 4).

    Returns:
        row (int): ``logical_id // 2 + (local_id // 2) * 8``.
        col (int): ``(logical_id % 2) * 2 + local_id % 2``.
    """
    logical_id = get_logical_id(thread_id)
    row = logical_id // 2 + (local_id // 2) * 8
    col = (logical_id % 2) * 2 + local_id % 2
    return row, col
41+
42+
43+
def metadata_load_32x2_to_shared_16x2_layout_16bit(
    thread_id: int,
    local_id: int,
) -> Tuple[int, int]:
    """Map a (thread, local element) pair to its metadata slot (2 elements per thread).

    The thread id is first folded with ``get_logical_id``, so within each
    group of four consecutive threads only two distinct slots are addressed
    (the original note says which two threads actually load depends on the
    selector). The folded id supplies the base row and the column;
    ``local_id`` selects the row half (+8).

    Args:
        thread_id (int): The thread id in the warp, range [0, 31].
        local_id (int): The local id within the thread, range [0, 1] (u16 * 2).

    Returns:
        row (int): ``logical_id // 2 + local_id * 8``.
        col (int): ``logical_id % 2``.
    """
    logical_id = get_logical_id(thread_id)
    base_row = logical_id // 2
    return base_row + local_id * 8, logical_id % 2
59+
60+
if __name__ == "__main__":
    # Manual debugging aid: dumps a thread -> (row, col) table to stdout.
    #
    # Disabled dumps for the two metadata layouts, kept for reference:
    #
    # for thread_id in range(32):
    #     print(f"thread_id: {thread_id}, logical_id: {get_logical_id(thread_id)}")
    #     for local_id in range(4):
    #         row, col = metadata_load_32x4_to_shared_16x4_layout_8bit(thread_id, local_id)
    #         print(f"thread_id: {thread_id}, local_id: {local_id} => row: {row}, col: {col}")
    #
    # for tid in range(32):
    #     print(f"thread_id: {tid}, logical_id: {get_logical_id(tid)}")
    #     for lid in range(2):
    #         row, col = metadata_load_32x2_to_shared_16x2_layout_16bit(tid, lid)
    #         print(f"thread_id: {tid}, local_id: {lid} => row: {row}, col: {col}")

    # NOTE: this definition rebinds the name imported from .mma_layout at the
    # top of the module, but only when the file is executed as a script.
    def mma_load_b_32x4_to_shared_16x8_layout_16bit(thread_id, local_id):
        """
        groupID           = %laneid >> 2
        threadID_in_group = %laneid % 4

        row =  (threadID_in_group * 2) + (i & 0x1)          for bi where i <  2
               (threadID_in_group * 2) + (i & 0x1) + 8      for bi where i >= 2

        col = groupID
        """
        group_id = thread_id // 4
        thread_in_group = thread_id % 4
        row = thread_in_group * 2 + (local_id % 2) + (local_id // 2) * 8
        return row, group_id

    def mma_load_b_32x8_to_shared_16x16_layout_16bit_replicate_b(thread_id, local_id):
        """Tile the 16x8 B layout twice along the column axis.

        ``local_id % 4`` is resolved by the 16x8 layout above and
        ``local_id // 4`` shifts the column by 8.
        """
        block = local_id // 4
        base_row, base_col = mma_load_b_32x4_to_shared_16x8_layout_16bit(thread_id, local_id % 4)
        return base_row, base_col + 8 * block

    # Print the replicated-B mapping for 32 threads with 8 local elements each.
    for tid in range(32):
        for lid in range(8):
            row, col = mma_load_b_32x8_to_shared_16x16_layout_16bit_replicate_b(tid, lid)
            print(f"thread_id: {tid}, local_id: {lid} => row: {row}, col: {col}")

0 commit comments

Comments
 (0)