From e6c942f0331a3eef81531f402b7d1438f69a41df Mon Sep 17 00:00:00 2001
From: mgoin
Date: Wed, 23 Apr 2025 14:43:55 +0000
Subject: [PATCH] Add missing rocm_skinny_gemms kernel test to CI

Signed-off-by: mgoin
---
 tests/kernels/quant_utils.py                  | 60 ++++++++++++++++++
 tests/kernels/quantization/test_block_fp8.py  |  2 +-
 tests/kernels/quantization/test_block_int8.py |  2 +-
 .../test_rocm_skinny_gemms.py                 |  0
 tests/kernels/utils_block.py                  | 63 -------------------
 5 files changed, 62 insertions(+), 65 deletions(-)
 rename tests/kernels/{ => quantization}/test_rocm_skinny_gemms.py (100%)
 delete mode 100644 tests/kernels/utils_block.py

diff --git a/tests/kernels/quant_utils.py b/tests/kernels/quant_utils.py
index 498da6001ae9..764924f26783 100644
--- a/tests/kernels/quant_utils.py
+++ b/tests/kernels/quant_utils.py
@@ -87,3 +87,63 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
     ref_out = (as_float32_tensor(x) * ref_iscale).clamp(
                 fp8_traits_min, fp8_traits_max).to(FP8_DTYPE)
     return ref_out, ref_scale.view((1, ))
+
+
+def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor,
+                             As: torch.Tensor, Bs: torch.Tensor, block_size,
+                             output_dtype):
+    """This function performs matrix multiplication with block-wise
+    quantization using native torch.
+    It is agnostic to the input data type and can be used for both int8 and
+    fp8 data types.
+
+    It takes two input tensors `A` and `B` (int8) with scales `As` and
+    `Bs` (float32).
+    The output is returned in the specified `output_dtype`.
+    """
+    A = A.to(torch.float32)
+    B = B.to(torch.float32)
+    assert A.shape[-1] == B.shape[-1]
+    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
+    assert len(block_size) == 2
+    block_n, block_k = block_size[0], block_size[1]
+    assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
+    assert A.shape[:-1] == As.shape[:-1]
+
+    M = A.numel() // A.shape[-1]
+    N, K = B.shape
+    origin_C_shape = A.shape[:-1] + (N, )
+    A = A.reshape(M, A.shape[-1])
+    As = As.reshape(M, As.shape[-1])
+    n_tiles = (N + block_n - 1) // block_n
+    k_tiles = (K + block_k - 1) // block_k
+    assert n_tiles == Bs.shape[0]
+    assert k_tiles == Bs.shape[1]
+
+    C_shape = (M, N)
+    C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)
+
+    A_tiles = [
+        A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles)
+    ]
+    B_tiles = [[
+        B[
+            j * block_n:min((j + 1) * block_n, N),
+            i * block_k:min((i + 1) * block_k, K),
+        ] for i in range(k_tiles)
+    ] for j in range(n_tiles)]
+    C_tiles = [
+        C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles)
+    ]
+    As_tiles = [As[:, i:i + 1] for i in range(k_tiles)]
+
+    for i in range(k_tiles):
+        for j in range(n_tiles):
+            a = A_tiles[i]
+            b = B_tiles[j][i]
+            c = C_tiles[j]
+            s = As_tiles[i] * Bs[j][i]
+            c[:, :] += torch.matmul(a, b.t()) * s
+
+    C = C.reshape(origin_C_shape).to(output_dtype)
+    return C
diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py
index da594675e924..c57e39f42506 100644
--- a/tests/kernels/quantization/test_block_fp8.py
+++ b/tests/kernels/quantization/test_block_fp8.py
@@ -6,7 +6,7 @@
 import pytest
 import torch
 
-from tests.kernels.utils_block import native_w8a8_block_matmul
+from tests.kernels.quant_utils import native_w8a8_block_matmul
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import fused_moe
diff --git a/tests/kernels/quantization/test_block_int8.py b/tests/kernels/quantization/test_block_int8.py
index 943470ad113d..104f23fd7cd2 100644
--- a/tests/kernels/quantization/test_block_int8.py
+++ b/tests/kernels/quantization/test_block_int8.py
@@ -6,7 +6,7 @@
 import pytest
 import torch
 
-from tests.kernels.utils_block import native_w8a8_block_matmul
+from tests.kernels.quant_utils import native_w8a8_block_matmul
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import fused_moe
diff --git a/tests/kernels/test_rocm_skinny_gemms.py b/tests/kernels/quantization/test_rocm_skinny_gemms.py
similarity index 100%
rename from tests/kernels/test_rocm_skinny_gemms.py
rename to tests/kernels/quantization/test_rocm_skinny_gemms.py
diff --git a/tests/kernels/utils_block.py b/tests/kernels/utils_block.py
deleted file mode 100644
index c16cba50967e..000000000000
--- a/tests/kernels/utils_block.py
+++ /dev/null
@@ -1,63 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-
-import torch
-
-
-def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor,
-                             As: torch.Tensor, Bs: torch.Tensor, block_size,
-                             output_dtype):
-    """This function performs matrix multiplication with block-wise
-    quantization using native torch.
-    It is agnostic to the input data type and can be used for both int8 and
-    fp8 data types.
-
-    It takes two input tensors `A` and `B` (int8) with scales `As` and
-    `Bs` (float32).
-    The output is returned in the specified `output_dtype`.
-    """
-    A = A.to(torch.float32)
-    B = B.to(torch.float32)
-    assert A.shape[-1] == B.shape[-1]
-    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
-    assert len(block_size) == 2
-    block_n, block_k = block_size[0], block_size[1]
-    assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
-    assert A.shape[:-1] == As.shape[:-1]
-
-    M = A.numel() // A.shape[-1]
-    N, K = B.shape
-    origin_C_shape = A.shape[:-1] + (N, )
-    A = A.reshape(M, A.shape[-1])
-    As = As.reshape(M, As.shape[-1])
-    n_tiles = (N + block_n - 1) // block_n
-    k_tiles = (K + block_k - 1) // block_k
-    assert n_tiles == Bs.shape[0]
-    assert k_tiles == Bs.shape[1]
-
-    C_shape = (M, N)
-    C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)
-
-    A_tiles = [
-        A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles)
-    ]
-    B_tiles = [[
-        B[
-            j * block_n:min((j + 1) * block_n, N),
-            i * block_k:min((i + 1) * block_k, K),
-        ] for i in range(k_tiles)
-    ] for j in range(n_tiles)]
-    C_tiles = [
-        C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles)
-    ]
-    As_tiles = [As[:, i:i + 1] for i in range(k_tiles)]
-
-    for i in range(k_tiles):
-        for j in range(n_tiles):
-            a = A_tiles[i]
-            b = B_tiles[j][i]
-            c = C_tiles[j]
-            s = As_tiles[i] * Bs[j][i]
-            c[:, :] += torch.matmul(a, b.t()) * s
-
-    C = C.reshape(origin_C_shape).to(output_dtype)
-    return C
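Note (not part of the patch): a minimal usage sketch of the relocated helper, calling it from its new home in tests/kernels/quant_utils.py. The shapes, block sizes, and random scales below are illustrative assumptions only; the real coverage lives in test_block_fp8.py and test_block_int8.py.

    import torch

    from tests.kernels.quant_utils import native_w8a8_block_matmul

    # Toy int8 problem: A is (M, K) activations, B is (N, K) weights.
    # Scales are one float32 per (row, K-block) for A and one per
    # (N-block, K-block) for B, matching the asserts in the helper.
    M, N, K = 4, 256, 512
    block_n, block_k = 128, 128
    n_tiles = (N + block_n - 1) // block_n
    k_tiles = (K + block_k - 1) // block_k

    A = torch.randint(-128, 127, (M, K), dtype=torch.int8)
    B = torch.randint(-128, 127, (N, K), dtype=torch.int8)
    As = torch.rand(M, k_tiles, dtype=torch.float32)
    Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32)

    out = native_w8a8_block_matmul(A, B, As, Bs, [block_n, block_k],
                                   torch.float16)
    assert out.shape == (M, N)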