|
21 | 21 |
|
22 | 22 | use_v2 = args.use_v2 |
23 | 23 |
|
| 24 | + |
24 | 25 | def get_configs(): |
25 | 26 | iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256]) |
26 | 27 | return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] |
@@ -73,7 +74,7 @@ def MMA0( |
73 | 74 | T.gemm_v2(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) |
74 | 75 | else: |
75 | 76 | T.gemm_v1(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) |
76 | | - |
| 77 | + |
77 | 78 | @T.macro |
78 | 79 | def MMA1( |
79 | 80 | V: T.Tensor(kv_shape, dtype), |
@@ -224,11 +225,11 @@ def main( |
224 | 225 | profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) |
225 | 226 | print("All checks pass.") |
226 | 227 | latency = profiler.do_bench(ref_program_processed, warmup=500) |
227 | | - print("Ref: {:.2f} ms".format(latency)) |
228 | | - print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9)) |
| 228 | + print(f"Ref: {latency:.2f} ms") |
| 229 | + print(f"Ref: {total_flops / latency * 1e-9:.2f} TFlops") |
229 | 230 | latency = profiler.do_bench(warmup=500) |
230 | | - print("Tile-lang: {:.2f} ms".format(latency)) |
231 | | - print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) |
| 231 | + print(f"Tile-lang: {latency:.2f} ms") |
| 232 | + print(f"Tile-lang: {total_flops / latency * 1e-9:.2f} TFlops") |
232 | 233 | else: |
233 | 234 | kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal) |
234 | 235 | best_latency = kernel.latency |
|
0 commit comments