7 changes: 4 additions & 3 deletions tests/kernels/moe/test_deepgemm.py
@@ -13,9 +13,10 @@

# vLLM fused-expert reference (Triton fallback + DeepGEMM option)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import (calc_diff, per_block_cast_to_fp8,
per_token_group_cast_to_fp8)
from vllm.utils.deep_gemm import calc_diff, per_block_cast_to_fp8

BLOCK_SIZE = [128, 128]

@@ -81,7 +82,7 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
"""
tokens_bf16 = torch.randn(
m, k, device="cuda", dtype=torch.bfloat16).clamp_min_(-1).clamp_max_(1)
_, a1_scale = per_token_group_cast_to_fp8(tokens_bf16, block_size[1])
_, a1_scale = per_token_group_quant_fp8(tokens_bf16, block_size[1])

# expert weight tensors
w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k,
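The test now quantizes the activation side with vLLM's own `per_token_group_quant_fp8` instead of the removed DeepGEMM wrapper. For reference, a minimal plain-PyTorch sketch of what per-token-group FP8 quantization computes; the helper name and `eps` default below are illustrative assumptions, not the vLLM Triton kernel:

```python
# Minimal plain-PyTorch sketch of per-token-group FP8 quantization, the
# operation per_token_group_quant_fp8 implements with a Triton kernel.
# Helper name and eps default are illustrative assumptions.
import torch

def ref_per_token_group_quant_fp8(x: torch.Tensor, group_size: int = 128,
                                  eps: float = 1e-10):
    assert x.shape[-1] % group_size == 0
    info = torch.finfo(torch.float8_e4m3fn)
    groups = x.float().reshape(*x.shape[:-1], -1, group_size)
    absmax = groups.abs().amax(dim=-1, keepdim=True).clamp_min(eps)
    scale = absmax / info.max                      # one scale per group
    q = (groups / scale).clamp(info.min, info.max)
    q = q.reshape(x.shape).to(torch.float8_e4m3fn)
    return q, scale.squeeze(-1)                    # fp8 values, fp32 per-group scales

tokens = torch.randn(4, 256, dtype=torch.bfloat16)   # k = 256 -> 2 groups per token
q, s = ref_per_token_group_quant_fp8(tokens, group_size=128)
assert q.shape == tokens.shape and s.shape == (4, 2)
```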
5 changes: 2 additions & 3 deletions tests/kernels/quantization/test_block_fp8.py
@@ -15,8 +15,7 @@
w8a8_block_fp8_matmul)
from vllm.platforms import current_platform
from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import (fp8_gemm_nt, per_block_cast_to_fp8,
per_token_group_cast_to_fp8)
from vllm.utils.deep_gemm import fp8_gemm_nt, per_block_cast_to_fp8

if current_platform.get_device_capability() < (9, 0):
pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
@@ -117,7 +116,7 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max

A_fp8, As_fp8 = per_token_group_cast_to_fp8(A_fp32, block_size[1])
A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_size[1])
B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32)

As = As_fp8.to(torch.float32)
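Only the activation-side quantizer changes here; the weight side still goes through DeepGEMM's `per_block_cast_to_fp8`. A rough plain-PyTorch sketch of that weight-side counterpart, under the assumption of 128x128 blocks with one scale per block (illustration only; the real DeepGEMM helper may differ in padding and scale details):

```python
# Rough sketch of 128x128 per-block FP8 weight quantization (one scale per
# block), assumed to mirror per_block_cast_to_fp8. Illustration only.
import torch
import torch.nn.functional as F

def ref_per_block_cast_to_fp8(B: torch.Tensor, block: int = 128):
    n, k = B.shape
    info = torch.finfo(torch.float8_e4m3fn)
    Bp = F.pad(B.float(), (0, (-k) % block, 0, (-n) % block))  # pad to block multiples
    blocks = Bp.reshape(Bp.shape[0] // block, block, Bp.shape[1] // block, block)
    absmax = blocks.abs().amax(dim=(1, 3), keepdim=True).clamp_min(1e-10)
    scale = absmax / info.max
    q = (blocks / scale).clamp(info.min, info.max)
    q = q.reshape(Bp.shape).to(torch.float8_e4m3fn)[:n, :k]
    return q, scale.squeeze(-1).squeeze(1)         # fp8 [n, k], fp32 [n/block, k/block]
```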
13 changes: 7 additions & 6 deletions vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
@@ -15,9 +15,10 @@
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.utils import has_deep_gemm, round_up
from vllm.utils.deep_gemm import (m_grouped_fp8_gemm_nt_contiguous,
per_token_group_cast_to_fp8)
from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous

logger = init_logger(__name__)

@@ -170,10 +171,10 @@ def apply(
self.activation(activation, act_out, mm1_out.view(-1, N))

a2q_scale: Optional[torch.Tensor] = None
a2q, a2q_scale = per_token_group_cast_to_fp8(act_out,
self.block_shape[1],
column_major_scales=True,
out_q=quant_out)
a2q, a2q_scale = per_token_group_quant_fp8(act_out,
self.block_shape[1],
column_major_scales=True,
out_q=quant_out)

m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, w2_scale),
mm2_out, expert_ids)
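The call site keeps `column_major_scales=True` and `out_q=quant_out` (a preallocated output for the quantized activations); only the function behind it changes. As I read it, column-major scales means the scale tensor keeps its logical [tokens, k // group] shape but is laid out contiguously along the token dimension, which is the layout the grouped DeepGEMM kernel consumes. A small sketch of that layout; names and sizes are illustrative:

```python
# Illustrative sketch of the column-major scale layout requested above.
# Shapes and names are made up; only the stride pattern is the point.
import torch

num_tokens, k, group = 8, 512, 128
row_major = torch.rand(num_tokens, k // group)   # stride (k // group, 1)
col_major = row_major.t().contiguous().t()       # same values, stride (1, num_tokens)
assert col_major.shape == (num_tokens, k // group)
assert col_major.stride() == (1, num_tokens)
```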
7 changes: 1 addition & 6 deletions vllm/model_executor/layers/fused_moe/utils.py
Expand Up @@ -15,8 +15,6 @@
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv
from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_used,
per_token_group_cast_to_fp8)


@triton.jit
@@ -119,10 +117,7 @@ def _fp8_quantize(
assert not per_act_token
assert len(block_shape) == 2
_, block_k = block_shape[0], block_shape[1]
if is_blackwell_deep_gemm_used():
A, A_scale = per_token_group_cast_to_fp8(A, block_k)
else:
A, A_scale = per_token_group_quant_fp8(A, block_k)
A, A_scale = per_token_group_quant_fp8(A, block_k)
assert cdiv(A.size(-1), block_k) == A_scale.size(-1)

return A, A_scale
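With the Blackwell branch gone, `_fp8_quantize` always calls `per_token_group_quant_fp8`, and the trailing assert still pins down the contract: one scale per `block_k`-wide slice of the last dimension. A tiny sketch of that invariant, with arbitrary example sizes:

```python
from vllm.utils import cdiv      # ceiling division, as used in the assert

K, block_k = 7168, 128           # example sizes
num_scales = cdiv(K, block_k)    # what A_scale.size(-1) must equal
assert num_scales == K // block_k == 56
```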
15 changes: 12 additions & 3 deletions vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -20,6 +20,7 @@
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv, direct_register_custom_op, has_deep_gemm
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used

logger = init_logger(__name__)

@@ -256,6 +257,7 @@ def _per_token_group_quant_fp8(
# Information for float8
fp8_min,
fp8_max,
use_ue8m0: tl.constexpr,
# Meta-parameters
BLOCK: tl.constexpr,
):
@@ -285,7 +287,8 @@ def _per_token_group_quant_fp8(
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
# Quant
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
y_s = _absmax / fp8_max
scale_raw = _absmax / fp8_max
y_s = tl.math.exp2(tl.ceil(tl.log2(scale_raw))) if use_ue8m0 else scale_raw
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

tl.store(y_q_ptr + cols, y_q, mask=mask)
@@ -309,6 +312,7 @@ def _per_token_group_quant_fp8_colmajor(
# Information for float8
fp8_min,
fp8_max,
use_ue8m0: tl.constexpr,
# Meta-parameters
BLOCK: tl.constexpr,
):
@@ -347,7 +351,8 @@ def _per_token_group_quant_fp8_colmajor(
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
# Quant
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
y_s = _absmax / fp8_max
scale_raw = _absmax / fp8_max
y_s = tl.math.exp2(tl.ceil(tl.log2(scale_raw))) if use_ue8m0 else scale_raw
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

tl.store(y_q_ptr + cols, y_q, mask=mask)
@@ -373,9 +378,11 @@ def per_token_group_quant_fp8(
is supported for now.
column_major_scales: Outputs scales in column major.
out_q: Optional output tensor. If not provided, function will create.
Returns:
tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor for quantization.
Returns:
tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor.
"""
dtype = current_platform.fp8_dtype() if dtype is None else dtype
assert (x.shape[-1] % group_size == 0), (
@@ -418,6 +425,7 @@ def per_token_group_quant_fp8(
eps,
fp8_min=fp8_min,
fp8_max=fp8_max,
use_ue8m0=is_blackwell_deep_gemm_used(),
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
@@ -433,6 +441,7 @@ def per_token_group_quant_fp8(
eps,
fp8_min=fp8_min,
fp8_max=fp8_max,
use_ue8m0=is_blackwell_deep_gemm_used(),
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
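The substantive change is the `use_ue8m0` path in both kernels: on Blackwell with DeepGEMM enabled, each group's scale is rounded up to the next power of two so it can be carried as a UE8M0 (exponent-only) scale. A plain-PyTorch mirror of the scale computation in the Triton lines above (illustrative helper, not the vLLM kernel):

```python
# Plain-PyTorch mirror of the Triton scale computation: with use_ue8m0 the
# raw scale absmax / fp8_max is rounded up to a power of two via
# exp2(ceil(log2(x))). Illustrative helper, not the vLLM kernel.
import torch

def group_scale(y: torch.Tensor, eps: float = 1e-10, use_ue8m0: bool = False):
    fp8_max = torch.finfo(torch.float8_e4m3fn).max     # 448.0 for e4m3fn
    absmax = y.abs().max().clamp_min(eps)
    scale_raw = absmax / fp8_max
    return torch.exp2(torch.ceil(torch.log2(scale_raw))) if use_ue8m0 else scale_raw

y = torch.randn(128)
s_plain = group_scale(y)                    # arbitrary positive float
s_ue8m0 = group_scale(y, use_ue8m0=True)    # power of two, never smaller than s_plain
assert s_ue8m0 >= s_plain
assert torch.log2(s_ue8m0) == torch.log2(s_ue8m0).round()
```

Rounding the scale up (rather than to nearest) keeps `y / y_s` within the FP8 range; the cost is a slightly coarser scale than the exact absmax-derived one.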
21 changes: 0 additions & 21 deletions vllm/utils/deep_gemm.py
@@ -49,7 +49,6 @@ def _resolve_symbol(module, new: str, old: str) -> Callable[..., Any] | None:
_fp8_gemm_nt_impl: Callable[..., Any] | None = None
_grouped_impl: Callable[..., Any] | None = None
_grouped_masked_impl: Callable[..., Any] | None = None
_per_token_cast_impl: Callable[..., Any] | None = None
_per_block_cast_impl: Callable[..., Any] | None = None
else:
_dg = importlib.import_module("deep_gemm") # type: ignore
@@ -74,12 +73,9 @@ def _resolve_symbol(module, new: str, old: str) -> Callable[..., Any] | None:
try:
_math_mod = importlib.import_module(
"deep_gemm.utils.math") # type: ignore
_per_token_cast_impl = getattr(_math_mod, "per_token_cast_to_fp8",
None)
_per_block_cast_impl = getattr(_math_mod, "per_block_cast_to_fp8",
None)
except ModuleNotFoundError:
_per_token_cast_impl = None
_per_block_cast_impl = None


@@ -101,22 +97,6 @@ def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
return _grouped_masked_impl(*args, **kwargs)


def per_token_group_cast_to_fp8(x, group_size, *args, **kwargs):
"""Wrapper for token-wise FP8 quantisation.

• If DeepGEMM provides ``per_token_cast_to_fp8`` (new API), use it.
• Otherwise, fall back to vLLM's ``per_token_group_quant_fp8``
"""

if _per_token_cast_impl is not None and is_blackwell_deep_gemm_used():
assert group_size == 128, "group_size must be 128 for deepgemm"
return _per_token_cast_impl(x)

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8 as _ptg)
return _ptg(x, group_size, *args, **kwargs)


def per_block_cast_to_fp8(x, *args, **kwargs):
if _per_block_cast_impl is not None and is_blackwell_deep_gemm_used():
return _per_block_cast_impl(x)
@@ -146,7 +126,6 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor):
"fp8_gemm_nt",
"m_grouped_fp8_gemm_nt_contiguous",
"fp8_m_grouped_gemm_nt_masked",
"per_token_group_cast_to_fp8",
"per_block_cast_to_fp8",
"is_blackwell_deep_gemm_used",
]
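
With the wrapper and its `_per_token_cast_impl` plumbing removed, call sites import the quantizer directly and the Blackwell decision lives inside it (via `is_blackwell_deep_gemm_used()`), so callers stay branch-free. The pattern, as the updated tests use it; it requires a CUDA device since the quantizer is a Triton kernel:

```python
# Call-site pattern after this PR, mirroring the updated tests. Requires a
# CUDA device; shapes are arbitrary examples.
import torch
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8)

x = torch.randn(16, 256, device="cuda", dtype=torch.bfloat16)
x_q, x_s = per_token_group_quant_fp8(x, 128)   # FP8 values + one scale per 128-wide group
```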