
Commit 6bae64f

[Enhancement] Add support for k_pack in gemm_mfma (#1344)
* add support for k_pack
* support benchmark on ROCm
* fix format
1 parent 4f84400 commit 6bae64f
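
In practice, k_pack lets one emitted MFMA step consume k_pack reduction-axis (K) micro-tiles at a time, trading a shorter inner k-loop for wider per-step fragments. A hedged sketch of the user-facing knob; treating T.gemm as the call target and all argument values as illustrative, since the hunks below only show the tail of the call:

    # Sketch: opting into k_pack inside a tilelang kernel (illustrative values).
    T.gemm(
        A_shared,                        # fp8 operand tile in shared memory
        B_shared,
        C_local,                         # float accumulator fragment
        transpose_B=True,
        policy=T.GemmWarpPolicy.Square,
        k_pack=2,                        # new knob: pack 2 K micro-tiles per MFMA step
    )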

File tree

4 files changed: +58 -13 lines changed

  benchmark/matmul_fp8/benchmark_matmul.py
  src/tl_templates/hip/hip_fp8.h
  tilelang/intrinsics/mfma_macro_generator.py
  tilelang/tileop/gemm/gemm_mfma.py

benchmark/matmul_fp8/benchmark_matmul.py
Lines changed: 5 additions & 1 deletion

@@ -1,5 +1,6 @@
 import argparse
 import itertools
+import torch
 import logging
 import tilelang
 import tilelang.language as T
@@ -99,6 +100,7 @@ def get_configs(args, kwargs):
         block_K=[64, 128],
         num_stages=[0, 1, 2, 3],
         thread_num=[128, 256],
+        k_pack=[1, 2],
         policy=[T.GemmWarpPolicy.Square],
         enable_rasteration=[True, False],
     )
@@ -125,6 +127,7 @@ def matmul(
     block_K=None,
     num_stages=None,
     thread_num=None,
+    k_pack=None,
     policy=None,
     enable_rasteration=None,
 ):
@@ -156,7 +159,7 @@ def matmul(
 
     # Use half-precision for input data to reduce memory bandwidth,
     # accumulate in float for better numerical accuracy
-    dtype = "float8_e4m3"
+    dtype = "float8_e4m3fnuz" if torch.version.hip is not None else "float8_e4m3"
     accum_dtype = "float"
 
     @T.prim_func
@@ -210,6 +213,7 @@ def main(
                 C_local,
                 transpose_B=True,
                 policy=policy,
+                k_pack=k_pack,
             )
             # Write back the results from C_local to the global memory C
             T.copy(C_local, C_shared)
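
Two changes meet in this file: torch.version.hip (a version string on ROCm builds of PyTorch, None on CUDA builds) now gates the fp8 dtype, presumably because the targeted ROCm GPUs implement the fnuz e4m3 encoding rather than the OCP float8_e4m3; and k_pack=[1, 2] becomes one more axis of the autotuning sweep. A minimal sketch of both, with the config space trimmed to two axes for illustration:

    import itertools
    import torch

    # ROCm builds of PyTorch expose torch.version.hip; CUDA builds leave it None.
    dtype = "float8_e4m3fnuz" if torch.version.hip is not None else "float8_e4m3"

    # Adding k_pack=[1, 2] doubles the swept space along the new axis.
    space = {"block_K": [64, 128], "k_pack": [1, 2]}
    configs = [dict(zip(space, values)) for values in itertools.product(*space.values())]
    assert len(configs) == 4  # 2 block_K choices x 2 k_pack choices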

src/tl_templates/hip/hip_fp8.h
Lines changed: 38 additions & 0 deletions

@@ -127,3 +127,41 @@ __device__ fp8_e4_8_t make_fp8_e4_8_t(fp8_e4_t x, fp8_e4_t y, fp8_e4_t z,
   res.y = *reinterpret_cast<fp8_e4_4_t *>(&b);
   return res;
 }
+
+__device__ fp8_e4_16_t make_fp8_e4_16_t(fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2,
+                                        fp8_e4_t x3, fp8_e4_t x4, fp8_e4_t x5,
+                                        fp8_e4_t x6, fp8_e4_t x7, fp8_e4_t y0,
+                                        fp8_e4_t y1, fp8_e4_t y2, fp8_e4_t y3,
+                                        fp8_e4_t y4, fp8_e4_t y5, fp8_e4_t y6,
+                                        fp8_e4_t y7) {
+  signed char x0_char = *reinterpret_cast<signed char *>(&x0);
+  signed char x1_char = *reinterpret_cast<signed char *>(&x1);
+  signed char x2_char = *reinterpret_cast<signed char *>(&x2);
+  signed char x3_char = *reinterpret_cast<signed char *>(&x3);
+  signed char x4_char = *reinterpret_cast<signed char *>(&x4);
+  signed char x5_char = *reinterpret_cast<signed char *>(&x5);
+  signed char x6_char = *reinterpret_cast<signed char *>(&x6);
+  signed char x7_char = *reinterpret_cast<signed char *>(&x7);
+  signed char y0_char = *reinterpret_cast<signed char *>(&y0);
+  signed char y1_char = *reinterpret_cast<signed char *>(&y1);
+  signed char y2_char = *reinterpret_cast<signed char *>(&y2);
+  signed char y3_char = *reinterpret_cast<signed char *>(&y3);
+  signed char y4_char = *reinterpret_cast<signed char *>(&y4);
+  signed char y5_char = *reinterpret_cast<signed char *>(&y5);
+  signed char y6_char = *reinterpret_cast<signed char *>(&y6);
+  signed char y7_char = *reinterpret_cast<signed char *>(&y7);
+  int a = (x3_char << 24) | (x2_char << 16) | (x1_char << 8) | x0_char;
+  int b = (x7_char << 24) | (x6_char << 16) | (x5_char << 8) | x4_char;
+  int c = (y3_char << 24) | (y2_char << 16) | (y1_char << 8) | y0_char;
+  int d = (y7_char << 24) | (y6_char << 16) | (y5_char << 8) | y4_char;
+  fp8_e4_8_t res_x;
+  res_x.x = *reinterpret_cast<fp8_e4_4_t *>(&a);
+  res_x.y = *reinterpret_cast<fp8_e4_4_t *>(&b);
+  fp8_e4_8_t res_y;
+  res_y.x = *reinterpret_cast<fp8_e4_4_t *>(&c);
+  res_y.y = *reinterpret_cast<fp8_e4_4_t *>(&d);
+  fp8_e4_16_t res;
+  res.x = res_x;
+  res.y = res_y;
+  return res;
+}
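
The new constructor widens the existing make_fp8_e4_8_t pattern: sixteen fp8 bytes are packed little-endian, four per 32-bit word, and the four words form the two fp8_e4_8_t halves of an fp8_e4_16_t. A pure-Python model of one word's layout (illustrative only, treating each fp8 value as an unsigned byte):

    def pack4(b0: int, b1: int, b2: int, b3: int) -> int:
        # Mirrors (b3 << 24) | (b2 << 16) | (b1 << 8) | b0 from the header:
        # b0 lands in the lowest byte, i.e. a little-endian word layout.
        return (b3 << 24) | (b2 << 16) | (b1 << 8) | b0

    word = pack4(0x01, 0x02, 0x03, 0x04)
    assert word == 0x04030201
    assert word.to_bytes(4, "little") == bytes([0x01, 0x02, 0x03, 0x04])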

tilelang/intrinsics/mfma_macro_generator.py
Lines changed: 5 additions & 4 deletions

@@ -372,8 +372,8 @@ def mfma(self,
 
         a_is_fragment = is_fragment(A_local_buf)
         b_is_fragment = is_fragment(B_local_buf)
-        a_local_stride: PrimExpr = k_inner * warp_rows * local_size_a if a_is_fragment else 0
-        b_local_stride: PrimExpr = k_inner * warp_cols * local_size_b if b_is_fragment else 0
+        a_local_stride: PrimExpr = k_inner * warp_rows * k_pack * local_size_a if a_is_fragment else 0
+        b_local_stride: PrimExpr = k_inner * warp_cols * k_pack * local_size_b if b_is_fragment else 0
 
         @T.macro
         def _warp_mfma(A_local_buf, B_local_buf, C_local_buf):
@@ -543,7 +543,8 @@ def forward_index(i: int, j: int) -> int:
             return local_id
 
         base_fragment = T.Fragment(
-            [micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s],
+            [micro_size_s, micro_size_r *
+             self.k_pack] if is_sr_axis_order else [micro_size_r * self.k_pack, micro_size_s],
             forward_thread_fn=forward_thread,
             forward_index_fn=forward_index,
         )
@@ -552,7 +553,7 @@ def forward_index(i: int, j: int) -> int:
         chunk = self.chunk
 
         warp_s = warp_rows if matrix_is_a else warp_cols
-        warp_r = chunk // micro_size_r
+        warp_r = chunk // (micro_size_r * self.k_pack)
         block_s = block_row_warps if matrix_is_a else block_col_warps
         replicate = block_col_warps if matrix_is_a else block_row_warps
 
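
All three hunks make the same substitution: every extent measured in reduction-axis micro-tiles is rescaled by k_pack, so a fragment becomes k_pack times wider along K while the per-warp tile count shrinks by the same factor. A hedged arithmetic check; the micro_size_r and chunk values are illustrative assumptions, not values from this diff:

    micro_size_r, chunk = 16, 64  # assumed illustrative MFMA tile sizes

    for k_pack in (1, 2):
        warp_r = chunk // (micro_size_r * k_pack)  # reduction tiles per warp
        frag_r = micro_size_r * k_pack             # fragment extent along K
        assert warp_r * frag_r == chunk            # K coverage is unchanged
        print(f"k_pack={k_pack}: {warp_r} tiles of width {frag_r}")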

tilelang/tileop/gemm/gemm_mfma.py
Lines changed: 10 additions & 8 deletions

@@ -28,6 +28,7 @@ def infer_layout(self, target: Target, thread_nums: int):
             warp_row_tiles=warp_row_tiles,
             warp_col_tiles=warp_col_tiles,
             chunk=self.chunk,
+            k_pack=self.k_pack,
         )
 
         if self.is_gemm_ss():
@@ -75,6 +76,7 @@ def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var:
             warp_col_tiles=warp_col_tiles,
             chunk=self.chunk,
             thread_var=thread_var,
+            k_pack=self.k_pack,
         )
 
         in_dtype = self.in_dtype
@@ -110,11 +112,11 @@ def _gemm_ssr() -> None:
             B_shared into local fragments, then issues Matrix Core mfma ops,
             accumulating into C_local.
             """
-            A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
-            B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
+            A_local = T.alloc_local((warp_rows * local_size_a * self.k_pack), in_dtype)
+            B_local = T.alloc_local((warp_cols * local_size_b * self.k_pack), in_dtype)
             if clear_accum:
                 T.clear(C_buf)
-            for ki in T.serial(0, (block_K // micro_size_k)):
+            for ki in T.serial(0, (block_K // (micro_size_k * self.k_pack))):
                 # Load A into fragment
                 mfma_emitter.ldmatrix_a(
                     A_local,
@@ -145,12 +147,12 @@ def _gemm_srr() -> None:
             B_shared into local fragments, then issues Matrix Core mfma ops,
             accumulating into C_local.
             """
-            A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
+            A_local = T.alloc_local((warp_rows * local_size_a * self.k_pack), in_dtype)
 
             if clear_accum:
                 T.clear(C_buf)
 
-            for ki in T.serial(0, (block_K // micro_size_k)):
+            for ki in T.serial(0, (block_K // (micro_size_k * self.k_pack))):
 
                 # Load A into fragment
                 mfma_emitter.ldmatrix_a(
@@ -177,10 +179,10 @@ def _gemm_rsr() -> None:
             B_shared into local fragments, then issues Matrix Core mfma ops,
             accumulating into C_local.
             """
-            B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
+            B_local = T.alloc_local((warp_cols * local_size_b * self.k_pack), in_dtype)
             if clear_accum:
                 T.clear(C_buf)
-            for ki in T.serial(0, (block_K // micro_size_k)):
+            for ki in T.serial(0, (block_K // (micro_size_k * self.k_pack))):
 
                 # Load B into fragment
                 mfma_emitter.ldmatrix_b(
@@ -207,7 +209,7 @@ def _gemm_rsr() -> None:
             accumulating into C_local.
             """
 
-            for ki in T.serial(0, (block_K // micro_size_k)):
+            for ki in T.serial(0, (block_K // (micro_size_k * self.k_pack))):
                 # Perform Matrix Multiplication
                 mfma_emitter.mfma(A_buf, B_buf, C_buf, ki)
 
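
The invariant behind these hunks: the per-iteration local fragments grow by k_pack while the k-loop trip count shrinks by k_pack, so each block_K slice still moves the same number of elements, just in fewer, wider MFMA steps. A hedged check with illustrative sizes (none of these constants come from the diff):

    block_K, micro_size_k = 128, 32   # assumed illustrative values
    warp_rows, local_size_a = 2, 8    # assumed illustrative values

    baseline = (block_K // micro_size_k) * (warp_rows * local_size_a)
    for k_pack in (1, 2):
        iters = block_K // (micro_size_k * k_pack)   # fewer loop iterations...
        a_local = warp_rows * local_size_a * k_pack  # ...but bigger fragments
        assert iters * a_local == baseline           # same elements per K slice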
