
Commit f9946bc

amd-hhashemi authored and choprahetarth committed
[ROCm] Add skinny gemm bias support for dtypes fp16,bf16,fp8 (vllm-project#24988)
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
Signed-off-by: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com>
1 parent eb2280d commit f9946bc

File tree

7 files changed: +233 -79 lines changed


csrc/rocm/ops.h

Lines changed: 6 additions & 3 deletions
@@ -5,11 +5,14 @@
 torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
                     const int64_t rows_per_block);

-torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
+torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
+                       const c10::optional<at::Tensor>& in_bias,
                        const int64_t CuCount);

-void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
-               at::Tensor& scale_a, at::Tensor& scale_b, const int64_t CuCount);
+void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b,
+               const c10::optional<at::Tensor>& in_bias, at::Tensor& out_c,
+               const at::Tensor& scale_a, const at::Tensor& scale_b,
+               const int64_t CuCount);

 void paged_attention(
     torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,

csrc/rocm/skinny_gemms.cu

Lines changed: 139 additions & 42 deletions
Large diffs are not rendered by default.

csrc/rocm/torch_bindings.cpp

Lines changed: 3 additions & 2 deletions
@@ -22,13 +22,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {

   // Custom gemm op for skinny matrix-matrix multiplication
   rocm_ops.def(
-      "wvSplitK(Tensor in_a, Tensor in_b, int CuCount) -> "
+      "wvSplitK(Tensor in_a, Tensor in_b, Tensor? in_bias, int CuCount) -> "
       "Tensor");
   rocm_ops.impl("wvSplitK", torch::kCUDA, &wvSplitK);

   // wvSplitK for fp8
   rocm_ops.def(
-      "wvSplitKQ(Tensor in_a, Tensor in_b, Tensor! out_c, Tensor scale_a, "
+      "wvSplitKQ(Tensor in_a, Tensor in_b, Tensor? in_bias, Tensor! out_c, "
+      "Tensor scale_a, "
       " Tensor scale_b, int CuCount) -> ()");
   rocm_ops.impl("wvSplitKQ", torch::kCUDA, &wvSplitKQ);
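With the schemas above, in_bias is nullable ("Tensor?"), so callers can keep passing None when no bias is needed. A minimal sketch of driving the registered ops directly, assuming a ROCm build of vLLM where the _rocm_C extension is loaded; shapes follow the convention used in the tests below (in_a is the weight of shape (m, k), in_b the activations of shape (n, k)):

import torch
import vllm._custom_ops  # noqa: F401  (loads torch.ops._rocm_C on ROCm builds)
from vllm.platforms import current_platform

n, k, m = 2, 512, 512
weight = torch.rand(m, k, dtype=torch.float16, device="cuda") - 0.5
x = torch.rand(n, k, dtype=torch.float16, device="cuda") - 0.5
bias = torch.rand(m, dtype=torch.float16, device="cuda") - 0.5
cu_count = current_platform.get_cu_count()

# Nullable bias: pass None to keep the old behaviour...
out = torch.ops._rocm_C.wvSplitK(weight, x, None, cu_count)
# ...or pass a bias tensor to have the kernel add it to the (n, m) output.
out_biased = torch.ops._rocm_C.wvSplitK(weight, x, bias, cu_count)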

tests/kernels/quantization/test_rocm_skinny_gemms.py

Lines changed: 62 additions & 18 deletions
@@ -1,12 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import math
+
 import pytest
 import torch

 import vllm._custom_ops as ops
 from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
-from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    rocm_per_tensor_w8a8_scaled_mm_impl)
 from vllm.platforms import current_platform

 DTYPES = [torch.bfloat16, torch.float16]
@@ -49,6 +49,7 @@
     (2, 512, 512),
     (3, 2048, 2048),
     (4, 4096, 4096),
+    (4, 16400, 2048),
     # Extended FP8 dimensions not covered by WVSPLITK
     (1, 14336, 1024),
     (2, 24576, 2048),
@@ -67,6 +68,9 @@
 @torch.inference_mode()
 def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
     torch.manual_seed(seed)
+    #TODO: Zero-centering the inputs causes errors for LLMM1!
+    # Without that the numbers quickly saturate, and may
+    # be giving false matches.
     A = torch.rand(n, k, dtype=dtype, device="cuda")
     B = torch.rand(m, k, dtype=dtype, device="cuda")

@@ -85,11 +89,51 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
     torch.manual_seed(seed)
     cu_count = current_platform.get_cu_count()

-    A = torch.rand(n, k, dtype=dtype, device="cuda")
-    B = torch.rand(m, k, dtype=dtype, device="cuda")
+    A = torch.rand(n, k, dtype=dtype, device="cuda") - .5
+    B = torch.rand(m, k, dtype=dtype, device="cuda") - .5

-    ref_out = torch.matmul(A, B.t())
-    out = ops.wvSplitK(B, A, cu_count)
+    ref_out = torch.nn.functional.linear(A, B)
+    out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count)
+
+    assert torch.allclose(out, ref_out, rtol=0.01)
+
+
+@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.skipif(not current_platform.is_rocm(),
+                    reason="only test for rocm")
+def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
+    torch.manual_seed(seed)
+    cu_count = current_platform.get_cu_count()
+
+    xavier = math.sqrt(2 / k)  # normalize to avoid large output-bias deltas
+    A = (torch.rand(n, k, dtype=dtype, device="cuda") - .5) * xavier
+    B = (torch.rand(m, k, dtype=dtype, device="cuda") - .5) * xavier
+    BIAS = torch.rand(m, dtype=dtype, device="cuda") - .5
+
+    ref_out = torch.nn.functional.linear(A, B, BIAS)
+    out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)
+
+    assert torch.allclose(out, ref_out, rtol=0.01)
+
+
+@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.skipif(not current_platform.is_rocm(),
+                    reason="only test for rocm")
+def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed):
+    torch.manual_seed(seed)
+    cu_count = current_platform.get_cu_count()
+
+    xavier = math.sqrt(2 / k)  # normalize to avoid large output-bias deltas
+    A = (torch.rand(n, k, dtype=dtype, device="cuda") - .5) * xavier
+    B = (torch.rand(m, k, dtype=dtype, device="cuda") - .5) * xavier
+    BIAS = torch.rand(n, m, dtype=dtype, device="cuda") - .5
+
+    ref_out = torch.nn.functional.linear(A, B, BIAS)
+    out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)

     assert torch.allclose(out, ref_out, rtol=0.01)

@@ -103,8 +147,8 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
 def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
     torch.manual_seed(seed)

-    A = torch.rand(n, k, device="cuda")
-    B = torch.rand(m, k, device="cuda")
+    A = torch.rand(n, k, device="cuda") - 0.5
+    B = torch.rand(m, k, device="cuda") - 0.5

     A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
     B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
@@ -123,27 +167,27 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
 @pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("use_bias", [True, False])
 @pytest.mark.skipif(
     not (current_platform.is_rocm() and current_platform.supports_fp8()),
     reason="only test for rocm fp8")
-def test_rocm_per_tensor_w8a8_scaled_mm_impl(n, k, m, dtype, seed, use_bias):
+def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed):
     torch.manual_seed(seed)

-    A = torch.rand(n, k, device="cuda")
-    B = torch.rand(m, k, device="cuda")
+    xavier = math.sqrt(2 / k)  # normalize to avoid large output-bias deltas
+    A = (torch.rand(n, k, device="cuda") - .5) * xavier
+    B = (torch.rand(m, k, device="cuda") - .5) * xavier
+    BIAS = torch.rand(m, dtype=dtype, device="cuda") - .5

     A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
     B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)

-    bias = torch.rand(1, m, dtype=dtype, device="cuda") if use_bias else None
-
-    output = rocm_per_tensor_w8a8_scaled_mm_impl(A, B.t(), dtype, scale_a,
-                                                 scale_b, bias)
     ref_out = torch._scaled_mm(A,
                                B.t(),
                                out_dtype=dtype,
                                scale_a=scale_a,
                                scale_b=scale_b,
-                               bias=bias)
-    assert torch.allclose(output, ref_out, rtol=0.01)
+                               bias=BIAS)
+    out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b,
+                        current_platform.get_cu_count(), BIAS)
+
+    assert torch.allclose(out, ref_out, rtol=0.01)
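The bias1D and bias2D tests reuse the same reference call because torch.nn.functional.linear accepts any bias shape broadcastable to the (n, m) output. A small CPU-only sketch of that equivalence, with toy sizes, purely for illustration:

import torch

n, k, m = 3, 8, 5
A = torch.rand(n, k) - 0.5
B = torch.rand(m, k) - 0.5

bias_1d = torch.rand(m) - 0.5            # one value per output feature
bias_2d = bias_1d.expand(n, m).clone()   # same values, materialized per row

ref_1d = torch.nn.functional.linear(A, B, bias_1d)
ref_2d = torch.nn.functional.linear(A, B, bias_2d)
assert torch.allclose(ref_1d, ref_2d)    # broadcasting makes them identical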

vllm/_custom_ops.py

Lines changed: 15 additions & 8 deletions
@@ -1447,17 +1447,24 @@ def LLMM1(a: torch.Tensor, b: torch.Tensor,
     return torch.ops._rocm_C.LLMM1(a, b, rows_per_block)


-def wvSplitK(a: torch.Tensor, b: torch.Tensor, cu_count: int) -> torch.Tensor:
-    return torch.ops._rocm_C.wvSplitK(a, b, cu_count)
-
-
-def wvSplitKQ(a: torch.Tensor, b: torch.Tensor, out_dtype: torch.dtype,
-              scale_a: torch.Tensor, scale_b: torch.Tensor,
-              cu_count: int) -> torch.Tensor:
+def wvSplitK(a: torch.Tensor,
+             b: torch.Tensor,
+             cu_count: int,
+             bias: torch.Tensor = None) -> torch.Tensor:
+    return torch.ops._rocm_C.wvSplitK(a, b, bias, cu_count)
+
+
+def wvSplitKQ(a: torch.Tensor,
+              b: torch.Tensor,
+              out_dtype: torch.dtype,
+              scale_a: torch.Tensor,
+              scale_b: torch.Tensor,
+              cu_count: int,
+              bias: torch.Tensor = None) -> torch.Tensor:
     out = torch.empty((b.shape[0], a.shape[0]),
                       dtype=out_dtype,
                       device=b.device)
-    torch.ops._rocm_C.wvSplitKQ(a, b, out, scale_a, scale_b, cu_count)
+    torch.ops._rocm_C.wvSplitKQ(a, b, bias, out, scale_a, scale_b, cu_count)
     return out
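Because bias is a trailing argument with a None default, existing call sites of ops.wvSplitK and ops.wvSplitKQ keep working unchanged. A rough usage sketch, assuming a ROCm build of vLLM; the fp8 quantization helper is the one the tests import, and the argument order mirrors the tests (weight first, activations second, scale_a belonging to the activations):

import torch
import vllm._custom_ops as ops
from vllm.platforms import current_platform
from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant

cu_count = current_platform.get_cu_count()
n, k, m = 1, 512, 512
x = torch.rand(n, k, dtype=torch.float16, device="cuda") - 0.5
weight = torch.rand(m, k, dtype=torch.float16, device="cuda") - 0.5
bias = torch.rand(m, dtype=torch.float16, device="cuda") - 0.5

# fp16/bf16 skinny gemm, with and without the new bias argument.
y = ops.wvSplitK(weight, x, cu_count)
y_biased = ops.wvSplitK(weight, x, cu_count, bias)

# fp8 skinny gemm: quantize per tensor, then pass bias as the last argument.
xq, scale_a = ref_dynamic_per_tensor_fp8_quant(x.float())
wq, scale_b = ref_dynamic_per_tensor_fp8_quant(weight.float())
y_fp8 = ops.wvSplitKQ(wq, xq, torch.float16, scale_a, scale_b, cu_count, bias)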

vllm/model_executor/layers/quantization/utils/w8a8_utils.py

Lines changed: 5 additions & 3 deletions
@@ -178,10 +178,12 @@ def rocm_per_tensor_w8a8_scaled_mm_impl(qinput: torch.Tensor,
                                         scale_b: torch.Tensor,
                                         bias: torch.Tensor) -> torch.Tensor:
     from vllm.platforms.rocm import on_mi3xx
-    if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx(
-    ) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0 and bias is None:
+    if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx() and \
+        qinput.shape[0] == 1 and \
+        qinput.shape[1] % 16 == 0 and \
+        ((bias is None) or (bias.dtype == out_dtype)):
         output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b,
-                               current_platform.get_cu_count())
+                               current_platform.get_cu_count(), bias)
     else:
         output = torch._scaled_mm(qinput,
                                   weight,
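The only behavioural change in this gate is the last clause: the skinny fp8 kernel is now taken when the bias is absent or already in the output dtype, instead of requiring bias is None. A standalone restatement of the predicate, with a hypothetical helper name, just to make the gating explicit:

from typing import Optional

import torch


def _use_skinny_fp8_gemm(qinput: torch.Tensor, bias: Optional[torch.Tensor],
                         out_dtype: torch.dtype, skinny_enabled: bool,
                         is_mi3xx: bool) -> bool:
    # Mirrors the updated condition in rocm_per_tensor_w8a8_scaled_mm_impl.
    return (skinny_enabled and is_mi3xx
            and qinput.shape[0] == 1           # single-row (skinny) input
            and qinput.shape[1] % 16 == 0      # K must be a multiple of 16
            and (bias is None or bias.dtype == out_dtype))  # new bias rule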

vllm/model_executor/layers/utils.py

Lines changed: 3 additions & 3 deletions
@@ -100,7 +100,7 @@ def rocm_unquantized_gemm_impl(
     k = weight.shape[1]
     use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx9() and \
                   x.dtype in [torch.float16, torch.bfloat16] \
-                  and k % 8 == 0 and bias is None)
+                  and k % 8 == 0)

     if use_skinny is not True:
         return torch.nn.functional.linear(x, weight, bias)
@@ -111,9 +111,9 @@ def rocm_unquantized_gemm_impl(
     cu_count = current_platform.get_cu_count()

     if m > 8 and 0 < n <= 4:
-        out = ops.wvSplitK(weight, x_view, cu_count)
+        out = ops.wvSplitK(weight, x_view, cu_count, bias)
         return out.view(*x.shape[:-1], weight.shape[0])
-    elif m % 4 == 0 and n == 1 and k <= 8192:
+    elif m % 4 == 0 and n == 1 and k <= 8192 and bias is None:
         out = ops.LLMM1(weight, x_view, 4)
         return out.view(*x.shape[:-1], weight.shape[0])
     return torch.nn.functional.linear(x, weight, bias)
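After this change the wvSplitK branch forwards the bias to the kernel, while LLMM1 remains bias-free and everything else falls back to torch.nn.functional.linear. A condensed restatement of the dispatch order, with a hypothetical helper name and thresholds copied from the code above:

from typing import Optional

import torch


def _rocm_gemm_path(m: int, n: int, k: int,
                    bias: Optional[torch.Tensor]) -> str:
    # m: output features, n: tokens, k: reduction dim; assumes use_skinny held.
    if m > 8 and 0 < n <= 4:
        return "wvSplitK"   # bias is now passed through to the kernel
    if m % 4 == 0 and n == 1 and k <= 8192 and bias is None:
        return "LLMM1"      # still requires no bias
    return "F.linear"       # fallback applies the bias itself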
