@@ -1,12 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import math
+
 import pytest
 import torch
 
 import vllm._custom_ops as ops
 from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
-from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    rocm_per_tensor_w8a8_scaled_mm_impl)
 from vllm.platforms import current_platform
 
 DTYPES = [torch.bfloat16, torch.float16]
@@ -49,6 +49,7 @@
     (2, 512, 512),
     (3, 2048, 2048),
     (4, 4096, 4096),
+    (4, 16400, 2048),
     # Extended FP8 dimensions not covered by WVSPLITK
     (1, 14336, 1024),
     (2, 24576, 2048),
@@ -67,6 +68,9 @@
 @torch.inference_mode()
 def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
     torch.manual_seed(seed)
+    # TODO: Zero-centering the inputs causes errors for LLMM1!
+    #       Without that the numbers quickly saturate, and may
+    #       be giving false matches.
     A = torch.rand(n, k, dtype=dtype, device="cuda")
     B = torch.rand(m, k, dtype=dtype, device="cuda")
 
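The TODO above is worth making concrete: with A and B drawn from U[0, 1), every output element of the GEMM is a sum of k products with mean 1/4, so it lands near k/4. At that magnitude the spacing between adjacent bf16 values already approaches what rtol=0.01 tolerates, so a genuinely wrong result can round into the accepted range. A back-of-envelope sketch (illustrative only, assuming k = 4096; not part of the test file):

import torch

k = 4096
a, b = torch.rand(k), torch.rand(k)    # U[0, 1), as in the LLMM1 test
print(a.dot(b).item())                 # ~ k / 4 = 1024
# bf16 spacing near 1024 is eps * 1024 ~= 8, while the test tolerates
# rtol * 1024 ~= 10 -- an error of a full ULP would still pass.
print(torch.finfo(torch.bfloat16).eps * 1024)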
@@ -85,11 +89,51 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
     torch.manual_seed(seed)
     cu_count = current_platform.get_cu_count()
 
-    A = torch.rand(n, k, dtype=dtype, device="cuda")
-    B = torch.rand(m, k, dtype=dtype, device="cuda")
+    A = torch.rand(n, k, dtype=dtype, device="cuda") - .5
+    B = torch.rand(m, k, dtype=dtype, device="cuda") - .5
 
-    ref_out = torch.matmul(A, B.t())
-    out = ops.wvSplitK(B, A, cu_count)
+    ref_out = torch.nn.functional.linear(A, B)
+    out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count)
+
+    assert torch.allclose(out, ref_out, rtol=0.01)
+
+
+@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.skipif(not current_platform.is_rocm(),
+                    reason="only test for rocm")
+def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
+    torch.manual_seed(seed)
+    cu_count = current_platform.get_cu_count()
+
+    xavier = math.sqrt(2 / k)  # normalize to avoid large output-bias deltas
+    A = (torch.rand(n, k, dtype=dtype, device="cuda") - .5) * xavier
+    B = (torch.rand(m, k, dtype=dtype, device="cuda") - .5) * xavier
+    BIAS = torch.rand(m, dtype=dtype, device="cuda") - .5
+
+    ref_out = torch.nn.functional.linear(A, B, BIAS)
+    out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)
+
+    assert torch.allclose(out, ref_out, rtol=0.01)
+
+
+@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.skipif(not current_platform.is_rocm(),
+                    reason="only test for rocm")
+def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed):
+    torch.manual_seed(seed)
+    cu_count = current_platform.get_cu_count()
+
+    xavier = math.sqrt(2 / k)  # normalize to avoid large output-bias deltas
+    A = (torch.rand(n, k, dtype=dtype, device="cuda") - .5) * xavier
+    B = (torch.rand(m, k, dtype=dtype, device="cuda") - .5) * xavier
+    BIAS = torch.rand(n, m, dtype=dtype, device="cuda") - .5
+
+    ref_out = torch.nn.functional.linear(A, B, BIAS)
+    out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)
 
     assert torch.allclose(out, ref_out, rtol=0.01)
 
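The xavier = math.sqrt(2 / k) factor in the two bias tests above deserves a note: with zero-centered U[-.5, .5) inputs the raw GEMM term has standard deviation ~sqrt(k)/12 (about 5.3 at k = 4096), which would dwarf a unit-scale bias and let bias-handling bugs hide inside the 1% relative tolerance. Scaling both operands by sqrt(2/k) shrinks the product term well below the bias scale, so the tolerance on the sum genuinely exercises the bias path. A rough sketch (shapes illustrative, not from the test file):

import math
import torch

k = 4096
A = torch.rand(16, k) - .5
B = torch.rand(512, k) - .5
print((A @ B.t()).std().item())        # ~ sqrt(k)/12 ~= 5.3, dwarfs a U[-.5, .5) bias
xavier = math.sqrt(2 / k)
# After Xavier-style scaling the GEMM term is ~1/(6*sqrt(k)) ~= 0.003,
# so the bias dominates the output being compared.
print(((A * xavier) @ (B * xavier).t()).std().item())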
@@ -103,8 +147,8 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
 def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
     torch.manual_seed(seed)
 
-    A = torch.rand(n, k, device="cuda")
-    B = torch.rand(m, k, device="cuda")
+    A = torch.rand(n, k, device="cuda") - 0.5
+    B = torch.rand(m, k, device="cuda") - 0.5
 
     A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
     B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
@@ -123,27 +167,27 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
 @pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("use_bias", [True, False])
 @pytest.mark.skipif(
     not (current_platform.is_rocm() and current_platform.supports_fp8()),
     reason="only test for rocm fp8")
-def test_rocm_per_tensor_w8a8_scaled_mm_impl(n, k, m, dtype, seed, use_bias):
+def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed):
    torch.manual_seed(seed)
 
-    A = torch.rand(n, k, device="cuda")
-    B = torch.rand(m, k, device="cuda")
+    xavier = math.sqrt(2 / k)  # normalize to avoid large output-bias deltas
+    A = (torch.rand(n, k, device="cuda") - .5) * xavier
+    B = (torch.rand(m, k, device="cuda") - .5) * xavier
+    BIAS = torch.rand(m, dtype=dtype, device="cuda") - .5
 
     A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
     B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
 
-    bias = torch.rand(1, m, dtype=dtype, device="cuda") if use_bias else None
-
-    output = rocm_per_tensor_w8a8_scaled_mm_impl(A, B.t(), dtype, scale_a,
-                                                 scale_b, bias)
     ref_out = torch._scaled_mm(A,
                                B.t(),
                                out_dtype=dtype,
                                scale_a=scale_a,
                                scale_b=scale_b,
-                               bias=bias)
-    assert torch.allclose(output, ref_out, rtol=0.01)
+                               bias=BIAS)
+    out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b,
+                        current_platform.get_cu_count(), BIAS)
+
+    assert torch.allclose(out, ref_out, rtol=0.01)
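For reference, the per-tensor quantization that the fp8 tests lean on boils down to deriving one dynamic scale from the tensor's max magnitude. A minimal sketch of that idea (assuming the float8_e4m3fn format; the real helper is ref_dynamic_per_tensor_fp8_quant in tests/kernels/quant_utils.py, and ROCm builds may use the e4m3fnuz variant with a smaller max):

import torch

def per_tensor_fp8_quant_sketch(x: torch.Tensor):
    # One dynamic scale for the whole tensor: map |x|.max() onto fp8's range.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max    # 448.0 for e4m3fn
    scale = x.abs().max().float() / fp8_max
    x_q = (x.float() / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return x_q, scale.reshape(1)    # dequantize with x_q.float() * scale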