diff --git a/benchmarks/kernels/benchmark_w8a8_block_fp8.py b/benchmarks/kernels/benchmark_w8a8_block_fp8.py index c6c8e0b0b936..602fad181074 100644 --- a/benchmarks/kernels/benchmark_w8a8_block_fp8.py +++ b/benchmarks/kernels/benchmark_w8a8_block_fp8.py @@ -14,7 +14,7 @@ from tqdm import tqdm from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - _w8a8_block_fp8_matmul, + _w8a8_triton_block_scaled_mm, ) from vllm.platforms import current_platform from vllm.triton_utils import triton @@ -83,7 +83,7 @@ def grid(META): ) if A.dtype == torch.float8_e4m3fn: - kernel = _w8a8_block_fp8_matmul + kernel = _w8a8_triton_block_scaled_mm else: raise RuntimeError("Currently, only support tune w8a8 block fp8 kernel.")