Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions tests/kernels/quant_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,63 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
ref_out = (as_float32_tensor(x) * ref_iscale).clamp(
fp8_traits_min, fp8_traits_max).to(FP8_DTYPE)
return ref_out, ref_scale.view((1, ))


def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor,
As: torch.Tensor, Bs: torch.Tensor, block_size,
output_dtype):
"""This function performs matrix multiplication with block-wise
quantization using native torch.
It is agnostic to the input data type and can be used for both int8 and
fp8 data types.

It takes two input tensors `A` and `B` (int8) with scales `As` and
`Bs` (float32).
The output is returned in the specified `output_dtype`.
"""
A = A.to(torch.float32)
B = B.to(torch.float32)
assert A.shape[-1] == B.shape[-1]
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
assert A.shape[:-1] == As.shape[:-1]

M = A.numel() // A.shape[-1]
N, K = B.shape
origin_C_shape = A.shape[:-1] + (N, )
A = A.reshape(M, A.shape[-1])
As = As.reshape(M, As.shape[-1])
n_tiles = (N + block_n - 1) // block_n
k_tiles = (K + block_k - 1) // block_k
assert n_tiles == Bs.shape[0]
assert k_tiles == Bs.shape[1]

C_shape = (M, N)
C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)

A_tiles = [
A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles)
]
B_tiles = [[
B[
j * block_n:min((j + 1) * block_n, N),
i * block_k:min((i + 1) * block_k, K),
] for i in range(k_tiles)
] for j in range(n_tiles)]
C_tiles = [
C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles)
]
As_tiles = [As[:, i:i + 1] for i in range(k_tiles)]

for i in range(k_tiles):
for j in range(n_tiles):
a = A_tiles[i]
b = B_tiles[j][i]
c = C_tiles[j]
s = As_tiles[i] * Bs[j][i]
c[:, :] += torch.matmul(a, b.t()) * s

C = C.reshape(origin_C_shape).to(output_dtype)
return C
2 changes: 1 addition & 1 deletion tests/kernels/quantization/test_block_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest
import torch

from tests.kernels.utils_block import native_w8a8_block_matmul
from tests.kernels.quant_utils import native_w8a8_block_matmul
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
Expand Down
2 changes: 1 addition & 1 deletion tests/kernels/quantization/test_block_int8.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest
import torch

from tests.kernels.utils_block import native_w8a8_block_matmul
from tests.kernels.quant_utils import native_w8a8_block_matmul
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
Expand Down
63 changes: 0 additions & 63 deletions tests/kernels/utils_block.py

This file was deleted.