import torch

from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
-                                       native_w8a8_block_matmul,
-                                       per_block_cast_to_fp8)
+                                       native_w8a8_block_matmul)
from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    per_token_group_quant_fp8, w8a8_block_fp8_matmul)
+    get_col_major_tma_aligned_tensor, per_token_group_quant_fp8,
+    w8a8_block_fp8_matmul)
from vllm.platforms import current_platform
-
-dg_available = False
-try:
-    import deep_gemm
-    dg_available = True
-except ImportError:
-    pass
+from vllm.utils import has_deep_gemm
+from vllm.utils.deep_gemm import (fp8_gemm_nt, per_block_cast_to_fp8,
+                                  per_token_group_cast_to_fp8)

if current_platform.get_device_capability() < (9, 0):
    pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
@@ -106,7 +102,8 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
@pytest.mark.parametrize(
    "M,N,K,block_size,out_dtype,seed",
    itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
-@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
+@pytest.mark.skipif(not has_deep_gemm(),
+                    reason="DeepGemm kernels not available.")
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
    # only aligned sizes
@@ -120,9 +117,7 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
    A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
    B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max

-    _, block_k = block_size[0], block_size[1]
-
-    A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_k)
+    A_fp8, As_fp8 = per_token_group_cast_to_fp8(A_fp32, block_size[1])
    B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32)

    As = As_fp8.to(torch.float32)
@@ -132,14 +127,14 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
                                       out_dtype)

    # Transpose earlier so that the testing will not trigger transposing kernels
-    As_fp8 = deep_gemm.get_col_major_tma_aligned_tensor(As_fp8)
+    As_fp8 = get_col_major_tma_aligned_tensor(As_fp8)

    out = torch.zeros((M, N), device='cuda', dtype=out_dtype)

    assert As_fp8.shape == (M, (K + 127) //
                            128), f"{As_fp8.shape} != {(M, (K + 127) // 128)}"

-    deep_gemm.gemm_fp8_fp8_bf16_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out)
+    fp8_gemm_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out)

    rel_diff = (torch.mean(
        torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
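
For reference, a minimal standalone sketch of the quantize, scale-align, and DeepGEMM call sequence this test exercises, using only the wrappers imported in the diff. It assumes a compute-capability 9.0+ CUDA device with DeepGemm installed (`has_deep_gemm()` returns True); the shapes, block size, and output dtype are illustrative assumptions, not taken from the test's parametrization.

```python
# Hedged sketch of the quantize -> align -> DeepGEMM path; shapes and dtypes
# here are illustrative assumptions.
import torch

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    get_col_major_tma_aligned_tensor)
from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import (fp8_gemm_nt, per_block_cast_to_fp8,
                                  per_token_group_cast_to_fp8)

if has_deep_gemm() and torch.cuda.is_available():
    M, N, K = 128, 512, 512      # illustrative, DeepGEMM-aligned sizes
    block_k = 128                # quantization group size along K

    A = torch.rand(M, K, dtype=torch.float32, device="cuda")
    B = torch.rand(N, K, dtype=torch.float32, device="cuda")

    # Per-token-group FP8 activations and per-block FP8 weights.
    A_fp8, As = per_token_group_cast_to_fp8(A, block_k)
    B_fp8, Bs = per_block_cast_to_fp8(B)

    # Align the activation scales up front so the GEMM call does not
    # fall back to a transposing kernel internally.
    As = get_col_major_tma_aligned_tensor(As)

    out = torch.zeros((M, N), device="cuda", dtype=torch.bfloat16)
    fp8_gemm_nt((A_fp8, As), (B_fp8, Bs), out)
```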