@@ -85,21 +85,13 @@ def benchmark_shape(m: int,
 
     # === DeepGEMM Implementation ===
     def deepgemm_gemm():
-        # A quantization is inside the loop as it depends on activations
-        # A_deepgemm, A_scale_deepgemm = per_token_cast_to_fp8(A)
-        # A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8(
-        #     A, block_size[1])
-        # A_scale_aligned = get_col_major_tma_aligned_tensor(A_scale_deepgemm)
-        # C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
         deep_gemm.gemm_fp8_fp8_bf16_nt((A_deepgemm, A_scale_deepgemm),
                                        (B_deepgemm, B_scale_deepgemm),
                                        C_deepgemm)
         return C_deepgemm
 
     # === vLLM Triton Implementation ===
     def vllm_triton_gemm():
-        # A quantization is inside the loop as it depends on activations
-        # A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1])
         return w8a8_block_fp8_matmul(A_vllm,
                                      B_vllm,
                                      A_scale_vllm,
@@ -109,9 +101,6 @@ def vllm_triton_gemm():
 
     # === vLLM CUTLASS Implementation ===
     def vllm_cutlass_gemm():
-        # A quantization is inside the loop as it depends on activations
-        # A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8(
-        #     A, block_size[1], column_major_scales=True)
         return ops.cutlass_scaled_mm(A_vllm_cutlass,
                                      B_vllm.T,
                                      scale_a=A_scale_vllm_cutlass,
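
For context, the commented-out lines removed above are leftovers: the timed closures reference activation tensors (`A_deepgemm`, `A_vllm`, `A_vllm_cutlass`) and their scales that must already exist in the enclosing `benchmark_shape()` scope, so only the GEMM kernels themselves are measured. Below is a minimal, hypothetical sketch of that per-shape setup, assuming the `per_token_group_quant_fp8` and `get_col_major_tma_aligned_tensor` helpers named in the deleted comments; the import paths, example shape, and block size are illustrative assumptions, not code from this commit.

```python
# Hypothetical sketch (not part of this commit): preparing the activation
# tensors the benchmarked closures rely on, once per shape, outside the loop.
import torch
import deep_gemm  # assumed to export get_col_major_tma_aligned_tensor
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8)  # assumed module path

m, n, k = 4096, 7168, 2048   # example GEMM shape (assumption)
block_size = [128, 128]      # assumed 128x128 block quantization granularity

A = torch.randn(m, k, device='cuda', dtype=torch.bfloat16)

# Per-token-group FP8 quantization of A (call signatures taken from the
# commented-out lines this commit deletes).
A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1])
A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8(
    A, block_size[1], column_major_scales=True)

# DeepGEMM additionally expects TMA-aligned activation scales and a
# preallocated bf16 output buffer.
A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8(A, block_size[1])
A_scale_deepgemm = deep_gemm.get_col_major_tma_aligned_tensor(A_scale_deepgemm)
C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
```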