Skip to content

Commit 050e3fe

Browse files
committed
lint fix
1 parent 9ff1255 commit 050e3fe

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

maint/gemm_v2/latency_mha_fwd_bhsd.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
use_v2 = args.use_v2
2323

24+
2425
def get_configs():
2526
iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256])
2627
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@@ -73,7 +74,7 @@ def MMA0(
7374
T.gemm_v2(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
7475
else:
7576
T.gemm_v1(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
76-
77+
7778
@T.macro
7879
def MMA1(
7980
V: T.Tensor(kv_shape, dtype),
@@ -224,11 +225,11 @@ def main(
224225
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
225226
print("All checks pass.")
226227
latency = profiler.do_bench(ref_program_processed, warmup=500)
227-
print("Ref: {:.2f} ms".format(latency))
228-
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
228+
print(f"Ref: {latency:.2f} ms")
229+
print(f"Ref: {total_flops / latency * 1e-9:.2f} TFlops")
229230
latency = profiler.do_bench(warmup=500)
230-
print("Tile-lang: {:.2f} ms".format(latency))
231-
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
231+
print(f"Tile-lang: {latency:.2f} ms")
232+
print(f"Tile-lang: {total_flops / latency * 1e-9:.2f} TFlops")
232233
else:
233234
kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal)
234235
best_latency = kernel.latency

0 commit comments

Comments
 (0)