9 changes: 6 additions & 3 deletions csrc/rocm/ops.h
@@ -5,11 +5,14 @@
 torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
                     const int64_t rows_per_block);
 
-torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
+torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
+                       const c10::optional<at::Tensor>& in_bias,
                        const int64_t CuCount);
 
-void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
-               at::Tensor& scale_a, at::Tensor& scale_b, const int64_t CuCount);
+void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b,
+               const c10::optional<at::Tensor>& in_bias, at::Tensor& out_c,
+               const at::Tensor& scale_a, const at::Tensor& scale_b,
+               const int64_t CuCount);
 
 void paged_attention(
     torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
181 changes: 139 additions & 42 deletions csrc/rocm/skinny_gemms.cu

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions csrc/rocm/torch_bindings.cpp
@@ -22,13 +22,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
 
   // Custom gemm op for skinny matrix-matrix multiplication
   rocm_ops.def(
-      "wvSplitK(Tensor in_a, Tensor in_b, int CuCount) -> "
+      "wvSplitK(Tensor in_a, Tensor in_b, Tensor? in_bias, int CuCount) -> "
       "Tensor");
   rocm_ops.impl("wvSplitK", torch::kCUDA, &wvSplitK);
 
   // wvSplitK for fp8
   rocm_ops.def(
-      "wvSplitKQ(Tensor in_a, Tensor in_b, Tensor! out_c, Tensor scale_a, "
+      "wvSplitKQ(Tensor in_a, Tensor in_b, Tensor? in_bias, Tensor! out_c, "
+      "Tensor scale_a, "
       "  Tensor scale_b, int CuCount) -> ()");
   rocm_ops.impl("wvSplitKQ", torch::kCUDA, &wvSplitKQ);
 
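A note on the schema change above: Tensor? in_bias makes the bias optional at the dispatcher level, so Python callers may pass a tensor or None, and the C++ entry points receive it as a c10::optional<at::Tensor> (see the ops.h hunk). A minimal sketch of what the registered op now accepts, mirroring the wrapper and tests further down; shapes are illustrative, not additional API:

import torch

from vllm.platforms import current_platform

cu_count = current_platform.get_cu_count()
weight = torch.randn(4096, 8192, dtype=torch.float16, device="cuda")  # (m, k)
x = torch.randn(1, 8192, dtype=torch.float16, device="cuda")          # (n, k)
bias = torch.randn(4096, dtype=torch.float16, device="cuda")          # (m,)

out = torch.ops._rocm_C.wvSplitK(weight, x, None, cu_count)    # bias omitted
out_b = torch.ops._rocm_C.wvSplitK(weight, x, bias, cu_count)  # bias applied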
80 changes: 62 additions & 18 deletions tests/kernels/quantization/test_rocm_skinny_gemms.py
@@ -1,12 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import math
+
 import pytest
 import torch
 
 import vllm._custom_ops as ops
 from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
-from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    rocm_per_tensor_w8a8_scaled_mm_impl)
 from vllm.platforms import current_platform
 
 DTYPES = [torch.bfloat16, torch.float16]
@@ -49,6 +49,7 @@
     (2, 512, 512),
     (3, 2048, 2048),
     (4, 4096, 4096),
+    (4, 16400, 2048),
     # Extended FP8 dimensions not covered by WVSPLITK
     (1, 14336, 1024),
     (2, 24576, 2048),
@@ -67,6 +68,9 @@
 @torch.inference_mode()
 def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
     torch.manual_seed(seed)
+    #TODO: Zero-centering the inputs causes errors for LLMM1!
+    #      Without that the numbers quickly saturate, and may
+    #      be giving false matches.
     A = torch.rand(n, k, dtype=dtype, device="cuda")
     B = torch.rand(m, k, dtype=dtype, device="cuda")
 
@@ -85,11 +89,51 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
     torch.manual_seed(seed)
     cu_count = current_platform.get_cu_count()
 
-    A = torch.rand(n, k, dtype=dtype, device="cuda")
-    B = torch.rand(m, k, dtype=dtype, device="cuda")
+    A = torch.rand(n, k, dtype=dtype, device="cuda") - .5
+    B = torch.rand(m, k, dtype=dtype, device="cuda") - .5
 
-    ref_out = torch.matmul(A, B.t())
-    out = ops.wvSplitK(B, A, cu_count)
+    ref_out = torch.nn.functional.linear(A, B)
+    out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count)
 
     assert torch.allclose(out, ref_out, rtol=0.01)
+
+
+@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.skipif(not current_platform.is_rocm(),
+                    reason="only test for rocm")
+def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
+    torch.manual_seed(seed)
+    cu_count = current_platform.get_cu_count()
+
+    xavier = math.sqrt(2 / k)  # normalize to avoid large output-bias deltas
+    A = (torch.rand(n, k, dtype=dtype, device="cuda") - .5) * xavier
+    B = (torch.rand(m, k, dtype=dtype, device="cuda") - .5) * xavier
+    BIAS = torch.rand(m, dtype=dtype, device="cuda") - .5
+
+    ref_out = torch.nn.functional.linear(A, B, BIAS)
+    out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)
+
+    assert torch.allclose(out, ref_out, rtol=0.01)
+
+
+@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.skipif(not current_platform.is_rocm(),
+                    reason="only test for rocm")
+def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed):
+    torch.manual_seed(seed)
+    cu_count = current_platform.get_cu_count()
+
+    xavier = math.sqrt(2 / k)  # normalize to avoid large output-bias deltas
+    A = (torch.rand(n, k, dtype=dtype, device="cuda") - .5) * xavier
+    B = (torch.rand(m, k, dtype=dtype, device="cuda") - .5) * xavier
+    BIAS = torch.rand(n, m, dtype=dtype, device="cuda") - .5
+
+    ref_out = torch.nn.functional.linear(A, B, BIAS)
+    out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)
+
+    assert torch.allclose(out, ref_out, rtol=0.01)
 
@@ -103,8 +147,8 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
 def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
     torch.manual_seed(seed)
 
-    A = torch.rand(n, k, device="cuda")
-    B = torch.rand(m, k, device="cuda")
+    A = torch.rand(n, k, device="cuda") - 0.5
+    B = torch.rand(m, k, device="cuda") - 0.5
 
     A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
     B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
@@ -123,27 +167,27 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
 @pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("use_bias", [True, False])
 @pytest.mark.skipif(
     not (current_platform.is_rocm() and current_platform.supports_fp8()),
     reason="only test for rocm fp8")
-def test_rocm_per_tensor_w8a8_scaled_mm_impl(n, k, m, dtype, seed, use_bias):
+def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed):
     torch.manual_seed(seed)
 
-    A = torch.rand(n, k, device="cuda")
-    B = torch.rand(m, k, device="cuda")
+    xavier = math.sqrt(2 / k)  # normalize to avoid large output-bias deltas
+    A = (torch.rand(n, k, device="cuda") - .5) * xavier
+    B = (torch.rand(m, k, device="cuda") - .5) * xavier
+    BIAS = torch.rand(m, dtype=dtype, device="cuda") - .5
 
     A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
     B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
 
-    bias = torch.rand(1, m, dtype=dtype, device="cuda") if use_bias else None
-
-    output = rocm_per_tensor_w8a8_scaled_mm_impl(A, B.t(), dtype, scale_a,
-                                                 scale_b, bias)
     ref_out = torch._scaled_mm(A,
                                B.t(),
                                out_dtype=dtype,
                                scale_a=scale_a,
                                scale_b=scale_b,
-                               bias=bias)
-    assert torch.allclose(output, ref_out, rtol=0.01)
+                               bias=BIAS)
+    out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b,
+                        current_platform.get_cu_count(), BIAS)
+
+    assert torch.allclose(out, ref_out, rtol=0.01)
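A note on how these bias tests are set up: all comparisons use torch.allclose with rtol=0.01, so the inputs are zero-centered and scaled by math.sqrt(2 / k) (the xavier factor) to keep the GEMM output from dwarfing an O(1) bias. With raw torch.rand inputs each output element is a sum of k positive terms and grows roughly like k / 4, at which point a purely relative tolerance could mask a missing or wrong bias (the TODO on the LLMM1 test above points at the same saturation issue). A small, kernel-independent illustration of the effect, with illustrative sizes:

import math
import torch

torch.manual_seed(0)
n, k, m = 4, 4096, 4096

# Raw uniform [0, 1) inputs: every output element is a sum of k positive
# terms, so its typical magnitude is about k / 4, and a U(-0.5, 0.5) bias
# would change it by far less than rtol = 0.01.
A = torch.rand(n, k)
B = torch.rand(m, k)
print((A @ B.t()).abs().mean())   # roughly k / 4, i.e. ~1e3 here

# Zero-centered, xavier-scaled inputs keep the output small relative to
# the bias, so the bias term is genuinely exercised by the rtol check.
xavier = math.sqrt(2 / k)
A = (torch.rand(n, k) - .5) * xavier
B = (torch.rand(m, k) - .5) * xavier
print((A @ B.t()).abs().mean())   # well below 1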
23 changes: 15 additions & 8 deletions vllm/_custom_ops.py
@@ -1447,17 +1447,24 @@ def LLMM1(a: torch.Tensor, b: torch.Tensor,
     return torch.ops._rocm_C.LLMM1(a, b, rows_per_block)
 
 
-def wvSplitK(a: torch.Tensor, b: torch.Tensor, cu_count: int) -> torch.Tensor:
-    return torch.ops._rocm_C.wvSplitK(a, b, cu_count)
-
-
-def wvSplitKQ(a: torch.Tensor, b: torch.Tensor, out_dtype: torch.dtype,
-              scale_a: torch.Tensor, scale_b: torch.Tensor,
-              cu_count: int) -> torch.Tensor:
+def wvSplitK(a: torch.Tensor,
+             b: torch.Tensor,
+             cu_count: int,
+             bias: torch.Tensor = None) -> torch.Tensor:
+    return torch.ops._rocm_C.wvSplitK(a, b, bias, cu_count)
+
+
+def wvSplitKQ(a: torch.Tensor,
+              b: torch.Tensor,
+              out_dtype: torch.dtype,
+              scale_a: torch.Tensor,
+              scale_b: torch.Tensor,
+              cu_count: int,
+              bias: torch.Tensor = None) -> torch.Tensor:
     out = torch.empty((b.shape[0], a.shape[0]),
                       dtype=out_dtype,
                       device=b.device)
-    torch.ops._rocm_C.wvSplitKQ(a, b, out, scale_a, scale_b, cu_count)
+    torch.ops._rocm_C.wvSplitKQ(a, b, bias, out, scale_a, scale_b, cu_count)
     return out
 
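For reference, the call pattern the updated wrappers expect, lifted from the tests above: the weight matrix is the first argument, activations are flattened to 2-D, and bias is a trailing optional. The shapes below are illustrative and assume a ROCm build with the skinny-GEMM constraints satisfied (k a multiple of 8, small batch); ref_dynamic_per_tensor_fp8_quant is the quantization helper the tests use:

import torch

import vllm._custom_ops as ops
from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
from vllm.platforms import current_platform

cu_count = current_platform.get_cu_count()

# bf16/fp16 skinny GEMM: weight (m, k), activations (n, k), bias (m,) or (n, m).
A = torch.randn(2, 4096, dtype=torch.float16, device="cuda")     # activations
B = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")  # weight
bias = torch.randn(4096, dtype=torch.float16, device="cuda")

out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count)          # bias defaults to None
out_b = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, bias)

# fp8 variant: per-tensor-quantized inputs plus scales; out_dtype picks the
# output type, and the bias dtype is expected to match it.
A_q, scale_a = ref_dynamic_per_tensor_fp8_quant(A.float())
B_q, scale_b = ref_dynamic_per_tensor_fp8_quant(B.float())
out_q = ops.wvSplitKQ(B_q, A_q, torch.float16, scale_a, scale_b, cu_count, bias)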
8 changes: 5 additions & 3 deletions vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -178,10 +178,12 @@ def rocm_per_tensor_w8a8_scaled_mm_impl(qinput: torch.Tensor,
                                         scale_b: torch.Tensor,
                                         bias: torch.Tensor) -> torch.Tensor:
     from vllm.platforms.rocm import on_mi3xx
-    if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx(
-    ) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0 and bias is None:
+    if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx() and \
+        qinput.shape[0] == 1 and \
+        qinput.shape[1] % 16 == 0 and \
+        ((bias is None) or (bias.dtype == out_dtype)):
         output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b,
-                               current_platform.get_cu_count())
+                               current_platform.get_cu_count(), bias)
     else:
         output = torch._scaled_mm(qinput,
                                   weight,
6 changes: 3 additions & 3 deletions vllm/model_executor/layers/utils.py
@@ -100,7 +100,7 @@ def rocm_unquantized_gemm_impl(
     k = weight.shape[1]
     use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx9() and \
                   x.dtype in [torch.float16, torch.bfloat16] \
-                  and k % 8 == 0 and bias is None)
+                  and k % 8 == 0)
 
     if use_skinny is not True:
         return torch.nn.functional.linear(x, weight, bias)
@@ -111,9 +111,9 @@ def rocm_unquantized_gemm_impl(
     cu_count = current_platform.get_cu_count()
 
     if m > 8 and 0 < n <= 4:
-        out = ops.wvSplitK(weight, x_view, cu_count)
+        out = ops.wvSplitK(weight, x_view, cu_count, bias)
         return out.view(*x.shape[:-1], weight.shape[0])
-    elif m % 4 == 0 and n == 1 and k <= 8192:
+    elif m % 4 == 0 and n == 1 and k <= 8192 and bias is None:
         out = ops.LLMM1(weight, x_view, 4)
         return out.view(*x.shape[:-1], weight.shape[0])
     return torch.nn.functional.linear(x, weight, bias)
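Net effect of the utils.py change: the bias-is-None requirement moves off the use_skinny gate and onto the LLMM1 branch only, so wvSplitK now serves biased skinny GEMMs while LLMM1 (which takes no bias) and all non-skinny shapes keep falling back to torch.nn.functional.linear. A condensed restatement of the dispatch as it reads after this diff; the helper name is illustrative, n is the flattened token count and m, k the weight's output/input dimensions per the surrounding function (not fully shown here), and skinny_enabled/dtype_ok stand in for the VLLM_ROCM_USE_SKINNY_GEMM, on_gfx9(), and fp16/bf16 checks:

def pick_rocm_gemm_path(n: int, m: int, k: int, has_bias: bool,
                        skinny_enabled: bool, dtype_ok: bool) -> str:
    """Illustrative restatement of rocm_unquantized_gemm_impl's dispatch."""
    if not (skinny_enabled and dtype_ok and k % 8 == 0):
        return "torch.nn.functional.linear"   # non-skinny dtypes/shapes
    if m > 8 and 0 < n <= 4:
        return "ops.wvSplitK"                 # now also handles the optional bias
    if m % 4 == 0 and n == 1 and k <= 8192 and not has_bias:
        return "ops.LLMM1"                    # LLMM1 still has no bias path
    return "torch.nn.functional.linear"       # fallback applies the bias itself


# A biased decode-style GEMM (one token) now stays on the skinny kernel:
print(pick_rocm_gemm_path(n=1, m=4096, k=4096, has_bias=True,
                          skinny_enabled=True, dtype_ok=True))   # ops.wvSplitK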