@@ -1,12 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import math
+
 import pytest
 import torch
 
 import vllm._custom_ops as ops
 from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
-from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    rocm_per_tensor_w8a8_scaled_mm_impl)
 from vllm.platforms import current_platform
 
 DTYPES = [torch.bfloat16, torch.float16]
@@ -49,6 +49,7 @@
     (2, 512, 512),
     (3, 2048, 2048),
     (4, 4096, 4096),
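+    # k=16400: a reduction dim that is not a multiple of 512, unlike
+    # the shapes above (likely to exercise uneven split-K tiling)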
+    (4, 16400, 2048),
     # Extended FP8 dimensions not covered by WVSPLITK
     (1, 14336, 1024),
     (2, 24576, 2048),
@@ -67,6 +68,9 @@
 @torch.inference_mode()
 def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
     torch.manual_seed(seed)
+    # TODO: Zero-centering the inputs causes errors for LLMM1!
+    # Without it, however, the numbers quickly saturate and may
+    # produce false matches.
     A = torch.rand(n, k, dtype=dtype, device="cuda")
     B = torch.rand(m, k, dtype=dtype, device="cuda")
 
@@ -85,11 +89,51 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
     torch.manual_seed(seed)
     cu_count = current_platform.get_cu_count()
 
-    A = torch.rand(n, k, dtype=dtype, device="cuda")
-    B = torch.rand(m, k, dtype=dtype, device="cuda")
+    A = torch.rand(n, k, dtype=dtype, device="cuda") - 0.5
+    B = torch.rand(m, k, dtype=dtype, device="cuda") - 0.5
 
-    ref_out = torch.matmul(A, B.t())
-    out = ops.wvSplitK(B, A, cu_count)
+    ref_out = torch.nn.functional.linear(A, B)
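+    # wvSplitK takes the weight matrix first, then a flattened 2D
+    # (tokens, k) view of the activations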
+    out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count)
+
+    assert torch.allclose(out, ref_out, rtol=0.01)
+
+
+@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.skipif(not current_platform.is_rocm(),
+                    reason="only test for rocm")
+def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
+    torch.manual_seed(seed)
+    cu_count = current_platform.get_cu_count()
+
+    xavier = math.sqrt(2 / k)  # normalize to avoid large output-bias deltas
+    A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
+    B = (torch.rand(m, k, dtype=dtype, device="cuda") - 0.5) * xavier
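+    # 1D bias of length m is broadcast across all n output rows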
+    BIAS = torch.rand(m, dtype=dtype, device="cuda") - 0.5
+
+    ref_out = torch.nn.functional.linear(A, B, BIAS)
+    out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)
+
+    assert torch.allclose(out, ref_out, rtol=0.01)
+
+
+@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.skipif(not current_platform.is_rocm(),
+                    reason="only test for rocm")
+def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed):
+    torch.manual_seed(seed)
+    cu_count = current_platform.get_cu_count()
+
+    xavier = math.sqrt(2 / k)  # normalize to avoid large output-bias deltas
+    A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
+    B = (torch.rand(m, k, dtype=dtype, device="cuda") - 0.5) * xavier
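+    # full (n, m) bias: an independent value per output element,
+    # with no broadcasting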
+    BIAS = torch.rand(n, m, dtype=dtype, device="cuda") - 0.5
+
+    ref_out = torch.nn.functional.linear(A, B, BIAS)
+    out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)
 
     assert torch.allclose(out, ref_out, rtol=0.01)
 
@@ -103,8 +147,8 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
     torch.manual_seed(seed)
 
-    A = torch.rand(n, k, device="cuda")
-    B = torch.rand(m, k, device="cuda")
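+    # zero-centered inputs give the fp8 quantizer a symmetric range
+    # (cf. the saturation note on the LLMM1 test above)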
+    A = torch.rand(n, k, device="cuda") - 0.5
+    B = torch.rand(m, k, device="cuda") - 0.5
 
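+    # dynamic per-tensor quantization: returns the fp8 tensors along
+    # with their per-tensor scales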
     A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
     B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
@@ -123,27 +167,27 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
 @pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("use_bias", [True, False])
 @pytest.mark.skipif(
     not (current_platform.is_rocm() and current_platform.supports_fp8()),
     reason="only test for rocm fp8")
-def test_rocm_per_tensor_w8a8_scaled_mm_impl(n, k, m, dtype, seed, use_bias):
+def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed):
     torch.manual_seed(seed)
 
-    A = torch.rand(n, k, device="cuda")
-    B = torch.rand(m, k, device="cuda")
+    xavier = math.sqrt(2 / k)  # normalize to avoid large output-bias deltas
+    A = (torch.rand(n, k, device="cuda") - 0.5) * xavier
+    B = (torch.rand(m, k, device="cuda") - 0.5) * xavier
+    BIAS = torch.rand(m, dtype=dtype, device="cuda") - 0.5
 
     A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
     B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
 
-    bias = torch.rand(1, m, dtype=dtype, device="cuda") if use_bias else None
-
-    output = rocm_per_tensor_w8a8_scaled_mm_impl(A, B.t(), dtype, scale_a,
-                                                 scale_b, bias)
     ref_out = torch._scaled_mm(A,
                                B.t(),
                                out_dtype=dtype,
                                scale_a=scale_a,
                                scale_b=scale_b,
-                               bias=bias)
-    assert torch.allclose(output, ref_out, rtol=0.01)
+                               bias=BIAS)
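+    # wvSplitKQ takes the fp8 weights and activations together with
+    # their per-tensor scales, the CU count, and the optional bias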
+    out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b,
+                        current_platform.get_cu_count(), BIAS)
+
+    assert torch.allclose(out, ref_out, rtol=0.01)