
Commit faed43b

Enable fp8 bias support and add corresponding tests. Adjust the skinny GEMM tests to be zero-centered, to avoid saturation and false passes.
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
1 parent b43cee1 commit faed43b
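
The zero-centering change is easier to see with a rough magnitude check. A minimal sketch (plain PyTorch, illustrative only, not part of the commit): with uniform [0, 1) inputs every product is positive, so a length-k dot product grows to roughly k/4 and a 1% relative tolerance swallows sizable absolute errors; zero-centered inputs keep outputs near zero, where the same errors still fail the comparison.

    import torch

    k = 4096
    a_pos, b_pos = torch.rand(1, k), torch.rand(1, k)              # mean ~0.5
    a_ctr, b_ctr = torch.rand(1, k) - 0.5, torch.rand(1, k) - 0.5  # mean ~0

    print((a_pos @ b_pos.t()).item())  # ~k/4 ~ 1024: rtol=0.01 tolerates errors ~10
    print((a_ctr @ b_ctr.t()).item())  # ~sqrt(k)/12 ~ 5: errors of that size still fail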

File tree

3 files changed: +81 −24 lines
csrc/rocm/skinny_gemms.cu
tests/kernels/quantization/test_rocm_skinny_gemms.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py

csrc/rocm/skinny_gemms.cu

Lines changed: 9 additions & 10 deletions
@@ -1338,7 +1338,7 @@ template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
 __global__ void __launch_bounds__(WvPrGrp* THRDS)
     wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, const int Bx,
                       const int By, const fp8_t* B, const fp8_t* __restrict__ A,
-                      const fp8_t* __restrict__ BIAS, scalar_t* C,
+                      const scalar_t* __restrict__ BIAS, scalar_t* C,
                       const float* __restrict__ s_A,
                       const float* __restrict__ s_B, const int _WvPrGrp,
                       const int CuCount) {
@@ -1491,8 +1491,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
       for (int n = 0; n < N; n++) {
         for (int y = 0; y < YTILE; y++) {
           // TODO: Determine data type conversion of bias for fp8
-          C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0] * sA *
-                                                 sB);  // + BIAS[(m+y)%Bx]);
+          scalar_t out = __float2s<scalar_t>(sum[n][y][0] * sA * sB);
+          C[m + y + n * M] = BIAS ? (out + BIAS[(m + y) % Bx + (n % By) * M]) : out;
        }
      }
    }
@@ -1506,7 +1506,7 @@ template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
 __global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M,
                                   const int Bx, const int By, const fp8_t* B,
                                   const fp8_t* __restrict__ A,
-                                  const fp8_t* __restrict__ BIAS, scalar_t* C,
+                                  const scalar_t* __restrict__ BIAS, scalar_t* C,
                                   const float* __restrict__ s_A,
                                   const float* __restrict__ s_B,
                                   const int _WvPrGrp, const int CuCount) {
@@ -1520,7 +1520,7 @@ template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
 __global__ void __launch_bounds__(WvPrGrp* THRDS)
     wvSplitKQ_hf_(const int K, const int Kp, const int M, const int Bx,
                   const int By, const fp8_t* B, const fp8_t* __restrict__ A,
-                  const fp8_t* __restrict__ BIAS, scalar_t* C,
+                  const scalar_t* __restrict__ BIAS, scalar_t* C,
                   const float* __restrict__ s_A, const float* __restrict__ s_B,
                   const int _WvPrGrp, const int CuCount) {
   constexpr int max_lds_len = LDS_SIZE;
@@ -1668,9 +1668,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
       for (int n = 0; n < N; n++) {
         for (int y = 0; y < YTILE; y++) {
           if (y + m >= M) break;  // To avoid mem access fault.
-          // TODO: Determine data type conversion of bias for fp8
-          C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0] * sA *
-                                                 sB);  // + BIAS[(m+y)%Bx]);
+          scalar_t out = __float2s<scalar_t>(sum[n][y][0] * sA * sB);
+          C[m + y + n * M] = BIAS ? (out + BIAS[(m + y) % Bx + (n % By) * M]) : out;
        }
      }
    }
@@ -1684,7 +1683,7 @@ template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
 __global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M,
                               const int Bx, const int By, const fp8_t* B,
                               const fp8_t* __restrict__ A,
-                              const fp8_t* __restrict__ BIAS, scalar_t* C,
+                              const scalar_t* __restrict__ BIAS, scalar_t* C,
                               const float* __restrict__ s_A,
                               const float* __restrict__ s_B, const int _WvPrGrp,
                               const int CuCount) {
@@ -1750,7 +1749,7 @@ void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b,
   auto a_ptr = in_a.data_ptr<fp8_t>();
   auto b_ptr = in_b.data_ptr<fp8_t>();
   auto bias_ptr = (in_bias.has_value() && in_bias->numel() > 0)
-                      ? in_bias->data_ptr<fp8_t>()
+                      ? reinterpret_cast<fptype*>(in_bias->data_ptr())
                       : nullptr;
   switch (N_in) {
     case 1:
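
The new bias indexing, BIAS[(m + y) % Bx + (n % By) * M], serves both bias shapes the Python tests exercise: a 1D bias of length M (By == 1) broadcast across all rows, and a 2D bias of shape (N, M) (By == N) read row by row. Below is a rough host-side Python mirror of that addressing, under the assumption that Bx and By are the bias width and row count as the kernel arguments suggest; this sketch is not part of the commit.

    import torch

    def add_bias_like_kernel(out, bias, M, N, Bx, By):
        # out: (N, M) GEMM result; bias: contiguous buffer of Bx * By elements.
        # Mirrors C[m + y + n * M] += BIAS[(m + y) % Bx + (n % By) * M].
        bias_flat = bias.reshape(-1)
        result = out.clone()
        for n in range(N):
            for m in range(M):
                result[n, m] += bias_flat[m % Bx + (n % By) * M]
        return result

With By == 1 the second term vanishes and every row sees the same length-M bias; with a contiguous (N, M) bias and By == N the index reduces to the usual n * M + m offset.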

tests/kernels/quantization/test_rocm_skinny_gemms.py

Lines changed: 70 additions & 11 deletions
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import pytest
 import torch
+import math
 
 import vllm._custom_ops as ops
 from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
@@ -47,6 +48,7 @@
     (2, 512, 512),
     (3, 2048, 2048),
     (4, 4096, 4096),
+    (4, 16400, 2048),
     # Extended FP8 dimensions not covered by WVSPLITK
     (1, 14336, 1024),
     (2, 24576, 2048),
@@ -65,6 +67,8 @@
 @torch.inference_mode()
 def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
     torch.manual_seed(seed)
+    #TODO: Zero-centering the inputs causes errors for LLMM1!
+    # Without that the numbers quickly saturate, and may be giving false matches.
     A = torch.rand(n, k, dtype=dtype, device="cuda")
     B = torch.rand(m, k, dtype=dtype, device="cuda")
 
@@ -83,8 +87,8 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
     torch.manual_seed(seed)
     cu_count = current_platform.get_cu_count()
 
-    A = torch.rand(n, k, dtype=dtype, device="cuda")
-    B = torch.rand(m, k, dtype=dtype, device="cuda")
+    A = torch.rand(n, k, dtype=dtype, device="cuda")-.5
+    B = torch.rand(m, k, dtype=dtype, device="cuda")-.5
 
     ref_out = torch.matmul(A, B.t())
     out = ops.wvSplitK(B, A, cu_count)
@@ -101,9 +105,10 @@ def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
     torch.manual_seed(seed)
     cu_count = current_platform.get_cu_count()
 
-    A = torch.rand(n, k, dtype=dtype, device="cuda")
-    B = torch.rand(m, k, dtype=dtype, device="cuda")
-    BIAS = torch.rand(m, dtype=dtype, device="cuda")
+    xavier = math.sqrt(2/k)  # normalize to avoid large output-bias deltas
+    A = (torch.rand(n, k, dtype=dtype, device="cuda")-.5)*xavier
+    B = (torch.rand(m, k, dtype=dtype, device="cuda")-.5)*xavier
+    BIAS = torch.rand(m, dtype=dtype, device="cuda")-.5
 
     ref_out = torch.matmul(A, B.t()) + BIAS
     out = ops.wvSplitK(B, A, cu_count, BIAS)
@@ -120,16 +125,16 @@ def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed):
     torch.manual_seed(seed)
     cu_count = current_platform.get_cu_count()
 
-    A = torch.rand(n, k, dtype=dtype, device="cuda")
-    B = torch.rand(m, k, dtype=dtype, device="cuda")
-    BIAS = torch.rand(n, m, dtype=dtype, device="cuda")
+    xavier = math.sqrt(2/k)  # normalize to avoid large output-bias deltas
+    A = (torch.rand(n, k, dtype=dtype, device="cuda")-.5)*xavier
+    B = (torch.rand(m, k, dtype=dtype, device="cuda")-.5)*xavier
+    BIAS = torch.rand(n, m, dtype=dtype, device="cuda")-.5
 
     ref_out = torch.matmul(A, B.t()) + BIAS
     out = ops.wvSplitK(B, A, cu_count, BIAS)
 
     assert torch.allclose(out, ref_out, rtol=0.01)
 
-
 @pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
@@ -139,8 +144,8 @@ def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed):
 def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
     torch.manual_seed(seed)
 
-    A = torch.rand(n, k, device="cuda")
-    B = torch.rand(m, k, device="cuda")
+    A = torch.rand(n, k, device="cuda")-0.5
+    B = torch.rand(m, k, device="cuda")-0.5
 
    A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
    B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
@@ -154,3 +159,57 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
                         current_platform.get_cu_count())
 
     assert torch.allclose(out, ref_out, rtol=0.01)
+
+@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.skipif(
+    not (current_platform.is_rocm() and current_platform.supports_fp8()),
+    reason="only test for rocm fp8")
+def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed):
+    torch.manual_seed(seed)
+
+    xavier = math.sqrt(2/k)  # normalize to avoid large output-bias deltas
+    A = (torch.rand(n, k, device="cuda")-.5)*xavier
+    B = (torch.rand(m, k, device="cuda")-.5)*xavier
+    BIAS = (torch.rand(m, dtype=dtype, device="cuda")-.5)
+
+    A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
+    B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
+
+    ref_out = torch._scaled_mm(A,
+                               B.t(),
+                               out_dtype=dtype,
+                               scale_a=scale_a,
+                               scale_b=scale_b) + BIAS
+    out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b,
+                        current_platform.get_cu_count(), BIAS)
+
+    assert torch.allclose(out, ref_out, rtol=0.01)
+
+@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.skipif(
+    not (current_platform.is_rocm() and current_platform.supports_fp8()),
+    reason="only test for rocm fp8")
+def test_rocm_wvsplitk_fp8_bias2D_kernel(n, k, m, dtype, seed):
+    torch.manual_seed(seed)
+
+    xavier = math.sqrt(2/k)  # normalize to avoid large output-bias deltas
+    A = (torch.rand(n, k, device="cuda")-.5)*xavier
+    B = (torch.rand(m, k, device="cuda")-.5)*xavier
+    BIAS = torch.rand(n, m, dtype=dtype, device="cuda")-.5
+
+    A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
+    B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
+
+    ref_out = torch._scaled_mm(A,
+                               B.t(),
+                               out_dtype=dtype,
+                               scale_a=scale_a,
+                               scale_b=scale_b) + BIAS
+    out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b,
+                        current_platform.get_cu_count(), BIAS)
+
+    assert torch.allclose(out, ref_out, rtol=0.01)
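
On a ROCm build of vLLM with fp8 support, the two new bias tests can be selected by keyword, e.g. (invocation is illustrative):

    pytest tests/kernels/quantization/test_rocm_skinny_gemms.py -k "fp8_bias" -v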

vllm/model_executor/layers/quantization/utils/w8a8_utils.py

Lines changed: 2 additions & 3 deletions
@@ -178,10 +178,9 @@ def rocm_per_tensor_w8a8_scaled_mm_impl(qinput: torch.Tensor,
                                         scale_b: torch.Tensor,
                                         bias: torch.Tensor) -> torch.Tensor:
     from vllm.platforms.rocm import on_mi3xx
-    if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx(
-    ) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0:
+    if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx() and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0 and bias.dtype == out_dtype :
         output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b,
-                               current_platform.get_cu_count())
+                               current_platform.get_cu_count(), bias)
     else:
         output = torch._scaled_mm(qinput,
                                   weight,
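
Read as a whole, the dispatch above now takes the fused-bias skinny GEMM path only when the bias dtype matches the requested output dtype; otherwise it falls back to torch._scaled_mm. A simplified, self-contained sketch of that decision follows; parameter names are illustrative, module-level lookups are passed in explicitly, and the fallback arguments are an assumption since the diff truncates them. This is not the verbatim vLLM code.

    import torch
    import vllm._custom_ops as ops
    from vllm.platforms import current_platform

    def rocm_scaled_mm_sketch(qinput, weight, out_dtype, scale_a, scale_b, bias,
                              use_skinny_gemm, is_mi3xx):
        # Take the skinny-GEMM path only for single-row inputs whose K dimension
        # is a multiple of 16 and whose bias already has the output dtype.
        if (use_skinny_gemm and is_mi3xx and qinput.shape[0] == 1
                and qinput.shape[1] % 16 == 0 and bias.dtype == out_dtype):
            return ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b,
                                 current_platform.get_cu_count(), bias)
        # Otherwise fall back to PyTorch's scaled matmul.
        return torch._scaled_mm(qinput, weight, scale_a=scale_a, scale_b=scale_b,
                                bias=bias, out_dtype=out_dtype)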
