@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import pytest
 import torch
+import math

 import vllm._custom_ops as ops
 from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
@@ -47,6 +48,7 @@
     (2, 512, 512),
     (3, 2048, 2048),
     (4, 4096, 4096),
+    (4, 16400, 2048),
     # Extended FP8 dimensions not covered by WVSPLITK
     (1, 14336, 1024),
     (2, 24576, 2048),
@@ -65,6 +67,8 @@
 @torch.inference_mode()
 def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
     torch.manual_seed(seed)
+    # TODO: Zero-centering the inputs causes errors for LLMM1!
+    # Without it the values quickly saturate and may yield false matches.
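+    # Note: torch.rand draws from [0, 1), so every partial dot product is
+    # positive and grows with k, which is what drives the saturation above.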
     A = torch.rand(n, k, dtype=dtype, device="cuda")
     B = torch.rand(m, k, dtype=dtype, device="cuda")

@@ -83,8 +87,8 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
     torch.manual_seed(seed)
     cu_count = current_platform.get_cu_count()

-    A = torch.rand(n, k, dtype=dtype, device="cuda")
-    B = torch.rand(m, k, dtype=dtype, device="cuda")
+    A = torch.rand(n, k, dtype=dtype, device="cuda") - 0.5
+    B = torch.rand(m, k, dtype=dtype, device="cuda") - 0.5

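+    # Note the argument order: wvSplitK(B, A, ...) computes A @ B.T,
+    # matching the torch.matmul reference below.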
     ref_out = torch.matmul(A, B.t())
     out = ops.wvSplitK(B, A, cu_count)
@@ -101,9 +105,10 @@ def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
     torch.manual_seed(seed)
     cu_count = current_platform.get_cu_count()

-    A = torch.rand(n, k, dtype=dtype, device="cuda")
-    B = torch.rand(m, k, dtype=dtype, device="cuda")
-    BIAS = torch.rand(m, dtype=dtype, device="cuda")
+    xavier = math.sqrt(2 / k)  # normalize to avoid large output-bias deltas
+    A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
+    B = (torch.rand(m, k, dtype=dtype, device="cuda") - 0.5) * xavier
+    BIAS = torch.rand(m, dtype=dtype, device="cuda") - 0.5
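+    # Unscaled, a length-k dot product of zero-mean U(-0.5, 0.5) entries has
+    # magnitude ~sqrt(k), dwarfing the O(1) bias; the sqrt(2/k) factor keeps
+    # the matmul term small enough that the bias actually affects the check.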

     ref_out = torch.matmul(A, B.t()) + BIAS
     out = ops.wvSplitK(B, A, cu_count, BIAS)
@@ -120,16 +125,16 @@ def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed):
     torch.manual_seed(seed)
     cu_count = current_platform.get_cu_count()

-    A = torch.rand(n, k, dtype=dtype, device="cuda")
-    B = torch.rand(m, k, dtype=dtype, device="cuda")
-    BIAS = torch.rand(n, m, dtype=dtype, device="cuda")
+    xavier = math.sqrt(2 / k)  # normalize to avoid large output-bias deltas
+    A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
+    B = (torch.rand(m, k, dtype=dtype, device="cuda") - 0.5) * xavier
+    BIAS = torch.rand(n, m, dtype=dtype, device="cuda") - 0.5

     ref_out = torch.matmul(A, B.t()) + BIAS
     out = ops.wvSplitK(B, A, cu_count, BIAS)

     assert torch.allclose(out, ref_out, rtol=0.01)

-
 @pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
@@ -139,8 +144,8 @@ def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed):
 def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
     torch.manual_seed(seed)

-    A = torch.rand(n, k, device="cuda")
-    B = torch.rand(m, k, device="cuda")
+    A = torch.rand(n, k, device="cuda") - 0.5
+    B = torch.rand(m, k, device="cuda") - 0.5

     A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
     B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
@@ -154,3 +159,57 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
                         current_platform.get_cu_count())

     assert torch.allclose(out, ref_out, rtol=0.01)
+
+@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.skipif(
+    not (current_platform.is_rocm() and current_platform.supports_fp8()),
+    reason="only test for rocm fp8")
+def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed):
+    torch.manual_seed(seed)
+
+    xavier = math.sqrt(2 / k)  # normalize to avoid large output-bias deltas
+    A = (torch.rand(n, k, device="cuda") - 0.5) * xavier
+    B = (torch.rand(m, k, device="cuda") - 0.5) * xavier
+    BIAS = torch.rand(m, dtype=dtype, device="cuda") - 0.5
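+    # Only A and B are quantized to fp8 below; BIAS stays in the requested
+    # output dtype, mirroring how the reference adds it after _scaled_mm.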
+
+    A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
+    B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
+
+    ref_out = torch._scaled_mm(A,
+                               B.t(),
+                               out_dtype=dtype,
+                               scale_a=scale_a,
+                               scale_b=scale_b) + BIAS
+    out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b,
+                        current_platform.get_cu_count(), BIAS)
+
+    assert torch.allclose(out, ref_out, rtol=0.01)
+
+@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.skipif(
+    not (current_platform.is_rocm() and current_platform.supports_fp8()),
+    reason="only test for rocm fp8")
+def test_rocm_wvsplitk_fp8_bias2D_kernel(n, k, m, dtype, seed):
+    torch.manual_seed(seed)
+
+    xavier = math.sqrt(2 / k)  # normalize to avoid large output-bias deltas
+    A = (torch.rand(n, k, device="cuda") - 0.5) * xavier
+    B = (torch.rand(m, k, device="cuda") - 0.5) * xavier
+    BIAS = torch.rand(n, m, dtype=dtype, device="cuda") - 0.5
+
+    A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
+    B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
+
+    ref_out = torch._scaled_mm(A,
+                               B.t(),
+                               out_dtype=dtype,
+                               scale_a=scale_a,
+                               scale_b=scale_b) + BIAS
+    out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b,
+                        current_platform.get_cu_count(), BIAS)
+
+    assert torch.allclose(out, ref_out, rtol=0.01)