From 7cdb800260d2565e478db9aa1b84fcac87644743 Mon Sep 17 00:00:00 2001 From: shuw Date: Wed, 2 Jul 2025 20:06:02 +0000 Subject: [PATCH 01/30] Flashinfer cutlass moe backend for TP/DP + EP. --- benchmarks/benchmark_throughput.py | 3 + .../base_device_communicator.py | 18 +- .../device_communicators/cuda_communicator.py | 65 ++++- .../device_communicators/pynccl.py | 74 ++++-- .../device_communicators/pynccl_wrapper.py | 32 +++ vllm/distributed/parallel_state.py | 25 +- vllm/envs.py | 5 + .../model_executor/layers/fused_moe/config.py | 28 +++ .../layers/fused_moe/cutlass_moe.py | 236 ++++++++++++------ .../fused_moe/flashinfer_cutlass_moe.py | 177 +++++++++++++ .../flashinfer_cutlass_prepare_finalize.py | 99 ++++++++ vllm/model_executor/layers/fused_moe/layer.py | 49 +++- .../layers/fused_moe/modular_kernel.py | 34 +-- vllm/model_executor/layers/fused_moe/utils.py | 14 ++ .../layers/quantization/modelopt.py | 190 ++++++++++++-- 15 files changed, 904 insertions(+), 145 deletions(-) create mode 100644 vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py create mode 100644 vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 0ded34c70bad..45c9fee66553 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -28,6 +28,7 @@ VisionArenaDataset, ) from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json +from vllm.distributed import cleanup_dist_env_and_memory from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs from vllm.entrypoints.openai.api_server import ( build_async_engine_client_from_engine_args, @@ -110,6 +111,8 @@ def run_vllm( ), ) end = time.perf_counter() + + cleanup_dist_env_and_memory() return end - start, outputs diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index 1bc2d8e0281c..69e6f405fdde 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import threading -from typing import Optional +from typing import List, Optional, Union from weakref import WeakValueDictionary import torch @@ -138,9 +138,23 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: input_size[dim + 1:]) return output_tensor + def all_gatherv(self, + input_: Union[torch.Tensor, List[torch.Tensor]], + dim: int = 0, + sizes: Optional[List[int]] = None): + assert False, "not implemented" + + def all_gatherv(self, + input_: Union[torch.Tensor, List[torch.Tensor]], + dim: int = 0, + sizes: Optional[List[int]] = None): + assert False, "not implemented" + def reduce_scatter(self, input_: torch.Tensor, - dim: int = -1) -> torch.Tensor: + dim: int = -1, + sizes: Optional[List[int]] = None) -> torch.Tensor: + assert sizes is None, "Varying size reduce scatter not supported with base device communicator" world_size = self.world_size # Bypass the function if we are using only 1 GPU. 
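# A minimal single-process sketch of the contract the all_gatherv stub above is
# meant to satisfy (an assumption drawn from the CUDA implementation later in
# this patch, not part of the original code): each rank contributes sizes[rank]
# rows along dim 0 and every rank receives the concatenation in rank order.
import torch

def all_gatherv_reference(per_rank_inputs: list, sizes: list) -> torch.Tensor:
    # Emulates the collective on one process: entry r stands in for rank r.
    assert [t.shape[0] for t in per_rank_inputs] == sizes
    return torch.cat(per_rank_inputs, dim=0)

chunks = [torch.randn(s, 8) for s in (2, 3, 1)]
assert all_gatherv_reference(chunks, [2, 3, 1]).shape == (6, 8)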
if world_size == 1: diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 3958d566b174..727c64518837 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from typing import List, Optional, Union import torch from torch.distributed import ProcessGroup @@ -117,7 +117,10 @@ def all_reduce(self, input_): torch.distributed.all_reduce(out, group=self.device_group) return out - def reduce_scatter(self, input_: torch.Tensor, dim: int = -1): + def reduce_scatter(self, + input_: torch.Tensor, + dim: int = -1, + sizes: Optional[List[int]] = None): world_size = self.world_size pynccl_comm = self.pynccl_comm assert pynccl_comm is not None @@ -129,15 +132,20 @@ def reduce_scatter(self, input_: torch.Tensor, dim: int = -1): # the input_tensor contiguous. Possible bug in reduce_scatter_tensor? input_tensor = input_.movedim(0, dim).contiguous() - assert input_tensor.shape[0] % world_size == 0 - chunk_size = input_tensor.shape[0] // world_size + if sizes is not None: + assert len(sizes) == world_size + assert input_tensor.shape[0] == sum(sizes) + chunk_size = sizes[self.rank_in_group] + else: + assert input_tensor.shape[0] % world_size == 0 + chunk_size = input_tensor.shape[0] // world_size output_shape = (chunk_size, ) + input_tensor.shape[1:] output = torch.empty(output_shape, dtype=input_tensor.dtype, device=input_tensor.device) - pynccl_comm.reduce_scatter(output, input_) + pynccl_comm.reduce_scatter(output, input_, sizes=sizes) # Reshape before returning return output.movedim(0, dim).contiguous() @@ -180,6 +188,53 @@ def destroy(self): self.all2all_manager.destroy() self.all2all_manager = None + """ + Allgather with support for list of tensors and varying sizes per rank. + Example: + Instead of: + ... = get_ep_group().dispatch(...) + Use this: + ... = get_dp_group().all_gatherv([topk_weights, topk_ids, a1q, a1q_scale], dim=0, sizes=get_forward_context().dp_metadata.num_tokens_across_dp_cpu) + """ + + def all_gatherv(self, + input_: Union[torch.Tensor, List[torch.Tensor]], + dim: int = 0, + sizes: Optional[List[int]] = None): + assert dim == 0, "only dim 0 all-gather is supported" + world_size = self.world_size + pynccl_comm = self.pynccl_comm + assert pynccl_comm is not None and not pynccl_comm.disabled + + def _all_gather_single(input_: torch.Tensor, + sizes: Optional[List[int]] = None): + input_size = input_.size() + if sizes is not None: + assert len(sizes) == world_size + assert input_.shape[dim] == sizes[self.rank_in_group] + output_size = (sum(sizes), ) + input_size[1:] + # 'sizes' is not needed if all inputs in the same group have the same shape + if all(s == sizes[0] for s in sizes): + sizes = None + else: + output_size = (input_size[0] * world_size, ) + input_size[1:] + # Allocate output tensor. 
+ output_tensor = torch.empty(output_size, + dtype=input_.dtype, + device=input_.device) + pynccl_comm.all_gather(output_tensor, input_, sizes=sizes) + return output_tensor + + if isinstance(input_, torch.Tensor): + return _all_gather_single(input_, sizes) + + pynccl_comm.group_start() + output_list = [] + for inp in input_: + output_list.append(_all_gather_single(inp, sizes=sizes)) + pynccl_comm.group_end() + return output_list + def dispatch( self, hidden_states: torch.Tensor, router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 29486292996a..5829e4e460cf 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -1,8 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Union +from typing import List, Optional, Union +import numpy as np # ===================== import region ===================== import torch import torch.distributed as dist @@ -135,7 +136,8 @@ def all_reduce(self, def all_gather(self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, - stream=None): + stream=None, + sizes: Optional[List[int]] = None): if self.disabled: return # nccl communicator created on a specific device @@ -146,17 +148,38 @@ def all_gather(self, f"but the input tensor is on {input_tensor.device}") if stream is None: stream = current_stream() - self.nccl.ncclAllGather( - buffer_type(input_tensor.data_ptr()), - buffer_type(output_tensor.data_ptr()), input_tensor.numel(), - ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm, - cudaStream_t(stream.cuda_stream)) + if sizes is not None: + assert output_tensor.shape[0] == sum(sizes) + numel_base = int(np.prod(output_tensor.shape[1:])) + split_offset = 0 + self.nccl.ncclGroupStart() + for root, split_size in enumerate(sizes): + dst_slice = output_tensor[split_offset:split_offset + + split_size] + self.nccl.ncclBroadcast( + buffer_type(input_tensor.data_ptr()), + buffer_type(dst_slice.data_ptr()), + split_size * numel_base, + ncclDataTypeEnum.from_torch(input_tensor.dtype), + root, + self.comm, + cudaStream_t(stream.cuda_stream), + ) + split_offset += split_size + self.nccl.ncclGroupEnd() + else: + self.nccl.ncclAllGather( + buffer_type(input_tensor.data_ptr()), + buffer_type(output_tensor.data_ptr()), input_tensor.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm, + cudaStream_t(stream.cuda_stream)) def reduce_scatter(self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, - stream=None): + stream=None, + sizes: Optional[List[int]] = None): if self.disabled: return # nccl communicator created on a specific device @@ -167,12 +190,29 @@ def reduce_scatter(self, f"but the input tensor is on {input_tensor.device}") if stream is None: stream = current_stream() - self.nccl.ncclReduceScatter( - buffer_type(input_tensor.data_ptr()), - buffer_type(output_tensor.data_ptr()), output_tensor.numel(), - ncclDataTypeEnum.from_torch(input_tensor.dtype), - ncclRedOpTypeEnum.from_torch(op), self.comm, - cudaStream_t(stream.cuda_stream)) + + if sizes is not None: + numel_base = int(np.prod(input_tensor.shape[1:])) + split_offset = 0 + self.nccl.ncclGroupStart() + for root, split_size in enumerate(sizes): + chunk = input_tensor[split_offset:split_offset + split_size, :] + self.nccl.ncclReduce( + buffer_type(chunk.data_ptr()), + 
buffer_type(output_tensor.data_ptr()), + split_size * numel_base, + ncclDataTypeEnum.from_torch(input_tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), root, self.comm, + cudaStream_t(stream.cuda_stream)) + split_offset += split_size + self.nccl.ncclGroupEnd() + else: + self.nccl.ncclReduceScatter( + buffer_type(input_tensor.data_ptr()), + buffer_type(output_tensor.data_ptr()), output_tensor.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), self.comm, + cudaStream_t(stream.cuda_stream)) def send(self, tensor: torch.Tensor, dst: int, stream=None): if self.disabled: @@ -216,3 +256,9 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None): self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(), ncclDataTypeEnum.from_torch(tensor.dtype), src, self.comm, cudaStream_t(stream.cuda_stream)) + + def group_start(self): + self.nccl.ncclGroupStart() + + def group_end(self): + self.nccl.ncclGroupEnd() diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py index 3018a92da07c..c8e772447ee9 100644 --- a/vllm/distributed/device_communicators/pynccl_wrapper.py +++ b/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -154,6 +154,16 @@ class NCCLLibrary: ncclRedOp_t, ncclComm_t, cudaStream_t ]), + # ncclResult_t ncclReduce( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclRedOp_t op, int root, + # ncclComm_t comm, cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function("ncclReduce", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ncclRedOp_t, ctypes.c_int, ncclComm_t, cudaStream_t + ]), # ncclResult_t ncclAllGather( # const void* sendbuff, void* recvbuff, size_t count, # ncclDataType_t datatype, ncclComm_t comm, @@ -207,6 +217,10 @@ class NCCLLibrary: # it is better not to call it at all. 
# ncclResult_t ncclCommDestroy(ncclComm_t comm); Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]), + # ncclResult_t ncclGroupStart(); + Function("ncclGroupStart", ncclResult_t, []), + # ncclResult_t ncclGroupEnd(); + Function("ncclGroupEnd", ncclResult_t, []), ] # class attribute to store the mapping from the path to the library @@ -300,6 +314,18 @@ def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type, datatype, op, comm, stream)) + def ncclReduce(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, op: int, root: int, + comm: ncclComm_t, stream: cudaStream_t) -> None: + # `datatype` actually should be `ncclDataType_t` + # and `op` should be `ncclRedOp_t` + # both are aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK(self._funcs["ncclReduce"](sendbuff, recvbuff, count, + datatype, op, root, comm, + stream)) + def ncclReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type, count: int, datatype: int, op: int, comm: ncclComm_t, stream: cudaStream_t) -> None: @@ -342,6 +368,12 @@ def ncclBroadcast(self, sendbuff: buffer_type, recvbuff: buffer_type, def ncclCommDestroy(self, comm: ncclComm_t) -> None: self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm)) + def ncclGroupStart(self) -> None: + self.NCCL_CHECK(self._funcs["ncclGroupStart"]()) + + def ncclGroupEnd(self) -> None: + self.NCCL_CHECK(self._funcs["ncclGroupEnd"]()) + __all__ = [ "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId", diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index c53601a22f21..f09e28c4edc2 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -30,7 +30,7 @@ from contextlib import contextmanager, nullcontext from dataclasses import dataclass from multiprocessing import shared_memory -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, List, Optional, Union from unittest.mock import patch import torch @@ -381,9 +381,16 @@ def _all_gather_out_place(self, input_: torch.Tensor, dim: int) -> torch.Tensor: return self.device_communicator.all_gather(input_, dim) + def all_gatherv(self, + input_: Union[torch.Tensor, List[torch.Tensor]], + dim: int = 0, + sizes: Optional[List[int]] = None): + return self.device_communicator.all_gatherv(input_, dim, sizes) + def reduce_scatter(self, input_: torch.Tensor, - dim: int = -1) -> torch.Tensor: + dim: int = -1, + sizes: Optional[List[int]] = None) -> torch.Tensor: world_size = self.world_size # Bypass the function if we are using only 1 GPU. 
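# Bookkeeping behind the sizes-aware NCCL paths added above (a sketch, not the
# library API): rank r owns rows [offsets[r], offsets[r + 1]) of the gathered
# tensor, and each per-root ncclBroadcast / ncclReduce call moves
# split_size * numel_base elements for that slice. Plain-Python illustration
# with made-up sizes:
sizes = [2, 3, 1]                 # rows contributed by each rank
offsets = [0]
for s in sizes:
    offsets.append(offsets[-1] + s)
assert offsets == [0, 2, 5, 6]    # rank r's slice is output[offsets[r]:offsets[r + 1]]

hidden = 4096
numel_base = hidden               # product of all non-split dimensions
elements_per_call = [s * numel_base for s in sizes]
assert sum(elements_per_call) == sum(sizes) * hidden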
if world_size == 1: @@ -392,16 +399,20 @@ def reduce_scatter(self, f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") if self.use_custom_op_call: + assert sizes is None, "Varying size reduce scatter not supported with vllm custom op" return torch.ops.vllm.reduce_scatter(input_, dim, world_size, group_name=self.unique_name) else: - return self._reduce_scatter_out_place(input_, dim) - - def _reduce_scatter_out_place(self, input_: torch.Tensor, - dim: int) -> torch.Tensor: - return self.device_communicator.reduce_scatter(input_, dim) + return self._reduce_scatter_out_place(input_, dim, sizes) + + def _reduce_scatter_out_place( + self, + input_: torch.Tensor, + dim: int, + sizes: Optional[List[int]] = None) -> torch.Tensor: + return self.device_communicator.reduce_scatter(input_, dim, sizes) def gather(self, input_: torch.Tensor, diff --git a/vllm/envs.py b/vllm/envs.py index 0cc6792d72bb..22eb7f8f8abe 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -121,6 +121,7 @@ VLLM_TPU_BUCKET_PADDING_GAP: int = 0 VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None VLLM_USE_DEEP_GEMM: bool = False + VLLM_USE_FLASHINFER_MOE: bool = False VLLM_XGRAMMAR_CACHE_MB: int = 0 VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256 VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False @@ -867,6 +868,10 @@ def get_vllm_port() -> Optional[int]: "VLLM_USE_DEEP_GEMM": lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))), + # Allow use of FlashInfer CUTLASS kernels for fused moe ops. + "VLLM_USE_FLASHINFER_MOE": + lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE", "0"))), + # Control the cache sized used by the xgrammar compiler. The default # of 512 MB should be enough for roughly 1000 JSON schemas. # It can be changed with this variable if needed for some reason. diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 9a678406b8f3..02d95cc2210e 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -15,6 +15,17 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from typing import TYPE_CHECKING + +try: + from flashinfer import fp4_quantize as fp4_quantize + from flashinfer.fused_moe import ( + cutlass_fused_moe as flashinfer_cutlass_fused_moe) +except ImportError: + if not TYPE_CHECKING: + flashinfer_cutlass_fused_moe = None +has_flashinfer = flashinfer_cutlass_fused_moe is not None + logger = init_logger(__name__) @@ -117,6 +128,11 @@ class FusedMoEParallelConfig: def use_all2all_kernels(self): return self.dp_size > 1 and self.use_ep + @property + def use_pplx_kernels(self): + return (self.use_all2all_kernels + and envs.VLLM_ALL2ALL_BACKEND == "pplx") + @property def use_pplx_kernels(self): return (self.use_all2all_kernels @@ -132,6 +148,10 @@ def use_deepep_ll_kernels(self): return (self.use_all2all_kernels and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") + @property + def use_flashinfer_cutlass_kernels(self): + return envs.VLLM_USE_FLASHINFER_MOE and has_flashinfer + @staticmethod def make(tp_size_: int, dp_size_: int, world_size_: int, vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig": @@ -335,6 +355,10 @@ def use_deepep_ht_kernels(self): def use_deepep_ll_kernels(self): return self.moe_parallel_config.use_deepep_ll_kernels + @property + def use_flashinfer_cutlass_kernels(self): + return self.moe_parallel_config.use_flashinfer_cutlass_kernels + @staticmethod def make( num_experts: int, @@ -377,6 +401,10 @@ def make( from 
vllm.model_executor.layers.quantization.fp8 import Fp8Config if quant_dtype is None and isinstance(quant_config, Fp8Config): quant_dtype = torch.float8_e4m3fn + + from vllm.model_executor.layers.quantization.modelopt import ModelOptNvFp4Config + if quant_dtype is None and isinstance(quant_config, ModelOptNvFp4Config): + quant_dtype = torch.uint8 if weight_quant is not None: per_out_ch_quant = ( diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 0ef4e4f767e3..213767867ff5 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -393,82 +393,41 @@ def cutlass_moe_fp8( FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max -def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor, - w1_fp4: torch.Tensor, w1_blockscale: torch.Tensor, - w1_alphas: torch.Tensor, a2_gscale: torch.Tensor, - w2_fp4: torch.Tensor, w2_blockscale: torch.Tensor, - w2_alphas: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, m: int, n: int, k: int, e: int, - device: torch.device): - """ - MoE implementation for FP4 Inputs - - # Gemm 1 - a: Input tensor: [m, k] (half/bfloat16) - a1_gscale: Activation scale per expert: [e] (float32) - w1(gate up) (not an argument to cutlass_moe_fp4): [e, 2 * n, k] - w1_fp4: [e, 2 * n, k // 2], dtype: torch.uint8 (stacked fp4: E2M1) - (Note: `n` is the up projection output dim, `k` is the input dim in - full precision) - w1_blockscale: [e, 2 * n, k // block_size] (float8_e4m3) - (Block size = 16 for NVFP4) - - # Gemm 2 - a2_gscale: Activation scale per expert: [e] - w2(down projection) (not an argument to cutlass_moe_fp4): [e, k, n] - w2_fp4: [e, k, n // 2], dtype: torch.uint8 (stacked E2M1) - w2_blockscale: [e, k, n // block_size], dtype: float8_e4m3 - - topk_weights: [m, topk] dtype: float8 - topk_ids: [m, topk] dtype: float8 - - m, n, k: Unquantized weight shapes, dtype: int - e: number of experts, dtype: int - - assumes that topk < k < n to satisfy - up/down projection expectations. 
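# Shape arithmetic implied by the docstring above, written out as a small sanity
# check (NVFP4 packs two e2m1 values per uint8 and uses 16-element blocks for
# the float8_e4m3 block scales); the concrete numbers are illustrative only.
e, n, k, block = 8, 256, 512, 16
w1_fp4_shape        = (e, 2 * n, k // 2)      # packed gate/up projection weight
w1_blockscale_shape = (e, 2 * n, k // block)
w2_fp4_shape        = (e, k, n // 2)          # packed down projection weight
w2_blockscale_shape = (e, k, n // block)
assert 2 * w1_fp4_shape[2] == k and 2 * w2_fp4_shape[2] == n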
- """ - assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" - assert w1_fp4.dtype == torch.uint8, "weight 1 must be uint8" - assert w2_fp4.dtype == torch.uint8, "weight 2 must be uint8" - assert (w1_fp4.ndim == 3 and w2_fp4.ndim == 3 and w1_blockscale.ndim == 3 - and w2_blockscale.ndim - == 3), ("All Weights must be of rank 3 for cutlass_moe_fp4") - m_a, k_a = a.shape - e_w1, nx2_w1, half_k_w1 = w1_fp4.shape - e_w2, k_w2, half_n_w2 = w2_fp4.shape - - assert (e_w1 == e_w2 and e_w1 == e), ("Number of experts must match", - " between weights.") - assert (k_a // 2 == half_k_w1 - and k == k_w2), ("Hidden size mismatch between a, w1 and w2") - assert (nx2_w1 == n * 2 and half_n_w2 == n // 2), ("mismatch in " - "expected `n`") - assert (m == m_a), "input shape mismatch" - assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1" - assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype" - assert (topk_weights.size(0) == m and topk_ids.size(0) - == m), ("topk must be provided for each row of a") - - out_dtype = a.dtype - num_topk = topk_ids.size(1) +# === FP4 Fused MoE Implementation === + +def run_cutlass_moe_fp4( + output: torch.Tensor, + hidden_states: torch.Tensor, + w1_fp4: torch.Tensor, + w1_blockscale: torch.Tensor, + w1_alphas: torch.Tensor, + w2_fp4: torch.Tensor, + w2_blockscale: torch.Tensor, + w2_alphas: torch.Tensor, + topk_ids: torch.Tensor, + a1_gscale: torch.Tensor, + a2_gscale: torch.Tensor, + topk_weights: torch.Tensor, + m: int, + n: int, + k: int, + e: int, + device: torch.device, +): + num_topk = topk_ids.shape[1] expert_offsets = torch.empty((e + 1), dtype=torch.int32, device=device) blockscale_offsets = torch.empty((e + 1), dtype=torch.int32, device=device) - # Problem size: (num_experts, (m,2n,k)) problem_sizes1 = torch.empty((e, 3), dtype=torch.int32, device=device) - # Problem size: (num_experts, (m,n,k)) problem_sizes2 = torch.empty((e, 3), dtype=torch.int32, device=device) - a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) - # problem shapes should have [m, n, k] - # Note that problem sizes are based on logical number of elements. ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1, problem_sizes2, a_map, c_map, e, n, k, blockscale_offsets) - a = ops.shuffle_rows(a, a_map) + a = ops.shuffle_rows(hidden_states, a_map) rep_a_fp4, rep_a_blockscale = ops.scaled_fp4_experts_quant( a, @@ -481,13 +440,12 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor, c1 = ops.cutlass_fp4_moe_mm(rep_a_fp4, w1_fp4, rep_a_blockscale, w1_blockscale, w1_alphas, problem_sizes1, expert_offsets[:-1], blockscale_offsets[:-1], - out_dtype, device) + output.dtype, device) del rep_a_fp4, rep_a_blockscale # hidden size dimension is split to one halfpytho sized tensor. 
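# The activation between the two grouped GEMMs here is SwiGLU: the first half
# of c1 along the hidden dimension is the gate projection, the second half the
# up projection, and torch.ops._C.silu_and_mul computes SiLU(gate) * up. A plain
# PyTorch reference of that step (a sketch, not the fused kernel), assuming c1
# has shape [m * topk, 2 * n]:
import torch
import torch.nn.functional as F

def silu_and_mul_reference(c1: torch.Tensor) -> torch.Tensor:
    gate, up = c1.chunk(2, dim=-1)   # [tokens, 2 * n] -> two [tokens, n] halves
    return F.silu(gate) * up         # [tokens, n]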
intermediate = torch.empty((m * num_topk, w1_fp4.size(1) // 2), device=device, - dtype=out_dtype) - + dtype=output.dtype) torch.ops._C.silu_and_mul(intermediate, c1) int_fp4, int_blockscale = ops.scaled_fp4_experts_quant( @@ -495,10 +453,146 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor, c2 = ops.cutlass_fp4_moe_mm(int_fp4, w2_fp4, int_blockscale, w2_blockscale, w2_alphas, problem_sizes2, expert_offsets[:-1], - blockscale_offsets[:-1], out_dtype, device) + blockscale_offsets[:-1], output.dtype, device) del int_fp4, int_blockscale c2 = ops.shuffle_rows(c2, c_map) - out = (c2.view(m, num_topk, k) * - topk_weights.view(m, num_topk, 1).half()).sum(dim=1) - return out.to(dtype=out_dtype) + out = (c2.view(m, num_topk, k) * topk_weights.view(m, num_topk, 1).half()).sum(dim=1) + output.copy_(out.to(dtype=output.dtype), non_blocking=True) + + +class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): + def __init__( + self, + max_experts_per_worker: int, + out_dtype: torch.dtype, + ): + super().__init__() + self.max_experts_per_worker = max_experts_per_worker + self.out_dtype = out_dtype + + def supports_chunking(self) -> bool: + return True + + def workspace_shapes( + self, + a: torch.Tensor, + aq: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + # Workspace1: for c1, Workspace2: for intermediate, Output: for final output + workspace1 = (M * topk, max(2 * N, K)) + workspace2 = (M * topk, N) + output = (M, K) + return (workspace1, workspace2, output, self.out_dtype) + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1_fp4: torch.Tensor, + w2_fp4: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_blockscale: Optional[torch.Tensor], + w2_blockscale: Optional[torch.Tensor], + w1_alphas: Optional[torch.Tensor], + w2_alphas: Optional[torch.Tensor], + a1_gscale: Optional[torch.Tensor], + a2_gscale: Optional[torch.Tensor], + expert_num_tokens: Optional[torch.Tensor], + topk_weights: Optional[torch.Tensor] = None, + m: Optional[int] = None, + n: Optional[int] = None, + k: Optional[int] = None, + e: Optional[int] = None, + device: Optional[torch.device] = None, + ): + # All required args must be provided + assert w1_blockscale is not None and w2_blockscale is not None + assert w1_alphas is not None and w2_alphas is not None + assert a1_gscale is not None and a2_gscale is not None + assert topk_weights is not None + assert m is not None and n is not None and k is not None and e is not None + assert device is not None + assert expert_map is None, ("Expert Parallelism / expert_map " + "is currently not supported for " + "ModelOptNvFp4FusedMoE.") + run_cutlass_moe_fp4( + output, + hidden_states, + w1_fp4, + w1_blockscale, + w1_alphas, + w2_fp4, + w2_blockscale, + w2_alphas, + topk_ids, + a1_gscale, + a2_gscale, + topk_weights, + m, + n, + k, + e, + device, + ) + + +def cutlass_moe_fp4( + a: torch.Tensor, + w1_fp4: torch.Tensor, + w2_fp4: torch.Tensor, + w1_blockscale: torch.Tensor, + w2_blockscale: torch.Tensor, + g1_alphas: torch.Tensor, + g2_alphas: torch.Tensor, + a1_gscale: torch.Tensor, + a2_gscale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + m: int, + n: int, + k: int, + e: int, + device: torch.device, +) -> torch.Tensor: + fn = mk.FusedMoEModularKernel( + MoEPrepareAndFinalizeNoEP( + quant_dtype=torch.uint8, # 
FP4 packed in uint8 + per_channel_quant=False, + ), + CutlassExpertsFp4( + max_experts_per_worker=e, + out_dtype=a.dtype, + ), + ) + return fn( + a, + w1_fp4, + w2_fp4, + topk_weights, + topk_ids, + False, + "silu", + e, + None, + w1_blockscale, + w2_blockscale, + g1_alphas, + g2_alphas, + a1_gscale=a1_gscale, + a2_gscale=a2_gscale, + m=m, + n=n, + k=k, + e=e, + device=device, + ) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py new file mode 100644 index 000000000000..f380e2847b75 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -0,0 +1,177 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional, Dict + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( + FlashInferCutlassMoEPrepareAndFinalize) +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig + +from vllm.utils import round_up + +logger = init_logger(__name__) + +from typing import TYPE_CHECKING + +try: + from flashinfer import fp4_quantize as fp4_quantize + from flashinfer.fused_moe import cutlass_fused_moe as cutlass_fused_moe +except ImportError: + if not TYPE_CHECKING: + cutlass_fused_moe = None + +has_flashinfer_cutlass_fused_moe = cutlass_fused_moe is not None + +#TODO(shuw): use this check +def _valid_flashinfer_fused_moe(hidden_states: torch.Tensor, w1: torch.Tensor, + w2: torch.Tensor) -> bool: + """ + Check if the given problem size is supported by the DeepGemm grouped + gemm kernel. All of M, N, K and the quantization block_shape must be + aligned by `dg.get_m_alignment_for_contiguous_layout()`. + """ + if not has_flashinfer_cutlass_fused_moe: + logger.debug( + "FlashInferExperts disabled: flashinfer_cutlass_fused_moe not available." + ) + return False + # Data type checks + if (w1.dtype != torch.uint8 or w2.dtype != torch.uint8 + or hidden_states.dtype + not in [torch.float32, torch.float16, torch.bfloat16]): + logger.debug( + f"FlashInferExperts disabled: w1/w2 must be torch.uint8 (got w1={w1.dtype}, w2={w2.dtype}), " + f"hidden_states must be float32, float16, or bfloat16 (got {hidden_states.dtype})." 
+ ) + return False + return True + +class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): + + def __init__(self, + use_nvfp4_w4a4: bool = False, + use_fp8_w8a8: bool = False, + use_dp: bool=False, + ep_rank: int=0, + ep_size: int=1, + tp_rank: int=0, + tp_size: int=1, + ): + super().__init__( + FusedMoEQuantConfig( + quant_dtype=torch.uint8, + per_act_token_quant=False, + block_shape=None, + )) + self.use_nvfp4_w4a4 = use_nvfp4_w4a4 + self.use_fp8_w8a8 = use_fp8_w8a8 + self.ep_rank=ep_rank + self.ep_size=ep_size + self.tp_rank=tp_rank + self.tp_size=tp_size + self.use_dp=use_dp + + @property + def activation_formats( + self + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + return (mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard) + + def supports_expert_map(self) -> bool: + return False + + def supports_chunking(self) -> bool: + #TODO(shuw): support chunking later + return False + + def workspace_shapes( + self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int, + topk: int, global_num_experts: int, local_num_experts: int + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + # We use global_num_experts due to how moe_align_block_size handles + # expert_maps. + """ + Compute the shapes for the temporary and final outputs of the two gemms + and activation in the fused expert function. Since the gemms are + independent, the workspace for the first gemm can be shared with the + workspace for the last gemm. + + Returns a tuple of: + - workspace13 shape tuple: must be large enough to hold the + result of either expert gemm. + - workspace2 shape tuple: must be large enough to hold the + result of the activation function. + - output shape tuple: must be exact size of the final gemm output. + - Workspace type: The dtype to use for the workspace tensors. + - Note: in order for activation chunking to work, the first dimension + of each tuple must be the number of tokens. + """ + # num_experts = global_num_experts + # block_m = self.block_shape[0] + # M_sum = (M * topk) + num_experts * (block_m - 1) + # M_sum = round_up(M_sum, block_m) + # workspace1 = () + workspace2 = () + output_shape = a.shape + workspace_dtype = a.dtype + workspace1 = output_shape + + return (workspace1, workspace2, output_shape, workspace_dtype) + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13:Optional[torch.Tensor], + workspace2:Optional[torch.Tensor], + expert_num_tokens: Optional[torch.Tensor], + topk_weights: torch.Tensor, + g1_alphas: torch.Tensor, + g2_alphas: torch.Tensor, + a1_scale: torch.Tensor, + out_dtype: torch.dtype, + ): + # Flashinfer CUTLASS kernel takes scalar global scales, + # min because inv_scale. 
+ if self.use_nvfp4_w4a4: + quant_scales = [ + torch.min(a1_scale), + w1_scale.view(torch.int32), + g1_alphas, + torch.min(a2_scale), + w2_scale.view(torch.int32), + g2_alphas, + ] + output = cutlass_fused_moe( + hidden_states, + topk_ids.to(torch.int), + topk_weights, + # FlashInfer API requires weight to be long for nvfp4 + w1.view(torch.long), + w2.view(torch.long), + output_dtype=out_dtype, + quant_scales=quant_scales, + input_sf=a1q_scale, + ep_size=self.ep_size, + ep_rank=self.ep_rank, + tp_size=self.tp_size, + tp_rank=self.tp_rank, + )[0] + else: + raise ValueError("Only nvfp4 quantization is currently supported.") diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py new file mode 100644 index 000000000000..a023a181f116 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -0,0 +1,99 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import torch +from flashinfer import fp4_swizzle_blockscale + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.distributed import get_dp_group +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.fused_moe.utils import ( + moe_kernel_quantize_input) +from vllm.model_executor.layers.fused_moe.prepare_finalize import ( + MoEPrepareAndFinalizeNoEP) + +def get_local_sizes(): + cu_sizes = get_forward_context().dp_metadata.cu_tokens_across_dp_cpu + sizes = [cu_sizes[0].item()] + for i in range(1, len(cu_sizes)): + sizes.append((cu_sizes[i] - cu_sizes[i - 1]).item()) + return sizes + + +class FlashInferCutlassMoEPrepareAndFinalize(MoEPrepareAndFinalizeNoEP + ): + + def __init__( + self, + quant_dtype: Optional[torch.dtype] = None, + per_channel_quant: bool = False, + block_shape: Optional[list[int]] = None, + ): + super().__init__() + self.per_channel_quant = per_channel_quant + self.block_shape = block_shape + self.quant_dtype = quant_dtype + + def max_num_tokens_per_rank(self) -> Optional[int]: + return None + + def topk_indices_dtype(self) -> Optional[torch.dtype]: + return None + + def prepare( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool = False, + use_dp: bool = True, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], + Optional[torch.Tensor], Optional[torch.Tensor]]: + + if apply_router_weight_on_input: + topk = topk_ids.size(1) + # TODO: this only works for topK=1, will need to update for topK>1 + assert topk == 1, \ + "apply_router_weight_on_input is only implemented for topk=1" + a1.mul_(topk_weights.to(a1.dtype)) + + a1q, a1q_scale = moe_kernel_quantize_input( + a1, + a1_scale, + self.quant_dtype, + self.per_channel_quant, + self.block_shape, + is_sf_swizzled_layout= + not use_dp, # Needs swizzling after communication + ) + if use_dp: + topk_weights, topk_ids, a1q, a1q_scale = \ + get_dp_group().all_gatherv([topk_weights, topk_ids, a1q, a1q_scale], + dim=0, + sizes=get_local_sizes()) + a1_m, a1_n = a1q.shape + a1q_scale = fp4_swizzle_blockscale(a1q_scale, a1_m, a1_n * 2) + + return a1q, a1q_scale, None, topk_ids, topk_weights + + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + 
topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + use_dp: bool = False, + ) -> None: + if use_dp: + fused_expert_output = get_dp_group().reduce_scatter( + fused_expert_output, + dim=0, + sizes=get_local_sizes(), + ) + output.copy_(fused_expert_output) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 6f9770262856..ecfceb9e57e3 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -36,6 +36,17 @@ from vllm.platforms.interface import CpuArchEnum from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx +from typing import TYPE_CHECKING + +try: + from flashinfer import fp4_quantize as fp4_quantize + from flashinfer.fused_moe import ( + cutlass_fused_moe as flashinfer_cutlass_fused_moe) +except ImportError: + if not TYPE_CHECKING: + flashinfer_cutlass_fused_moe = None +has_flashinfer = flashinfer_cutlass_fused_moe is not None + if current_platform.is_cuda_alike(): from .fused_batched_moe import BatchedTritonExperts from .fused_moe import TritonExperts, fused_experts @@ -46,6 +57,8 @@ from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SHAPE, DeepEPLLPrepareAndFinalize) + if has_flashinfer: + from .flashinfer_cutlass_prepare_finalize import FlashInferCutlassMoEPrepareAndFinalize else: fused_experts = None # type: ignore FusedMoEPermuteExpertsUnpermute = None # type: ignore @@ -76,6 +89,9 @@ class FusedMoEMethodBase(QuantizeMethodBase): moe: FusedMoEConfig + def select_experts_impl(self, moe_parallel_config): + pass + @abstractmethod def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, @@ -91,6 +107,10 @@ def init_prepare_finalize(self, moe: FusedMoEConfig, prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None + if moe.use_flashinfer_cutlass_kernels: + prepare_finalize = FlashInferCutlassMoEPrepareAndFinalize( + quant_dtype=moe.quant_dtype, + ) if moe.use_pplx_kernels: hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes( moe.max_num_tokens, @@ -743,7 +763,9 @@ def __init__( quant_method: Optional[QuantizeMethodBase] = None quant_method = (UnquantizedFusedMoEMethod(moe) if quant_config is None else quant_config.get_quant_method(self, prefix)) - + + quant_method.select_experts_impl(self.moe_parallel_config) + assert quant_method is not None assert isinstance(quant_method, FusedMoEMethodBase) self.quant_method = quant_method @@ -835,6 +857,10 @@ def use_deepep_ht_kernels(self): def use_deepep_ll_kernels(self): return self.moe_parallel_config.use_deepep_ll_kernels + @property + def use_flashinfer_cutlass_kernels(self): + return self.moe_parallel_config.use_flashinfer_cutlass_kernels + def _load_per_tensor_weight_scale(self, shard_id: str, param: torch.nn.Parameter, loaded_weight: torch.Tensor, @@ -905,12 +931,17 @@ def _load_w13(self, expert_data: torch.Tensor, shard_dim: int, shard_size) # Narrow parameter and load. # w1, gate_proj: Load into first logical weight of w13. - if shard_id == "w1": - expert_data = expert_data.narrow(shard_dim, 0, shard_size) # w3, up_proj: Load into second logical weight of w13. 
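# Sketch of the destination-offset choice implemented by the lines that follow:
# the FlashInfer CUTLASS kernel expects the fused w13 weight laid out as
# [up_proj, gate_proj] ("w3" first), while the default layout is
# [gate_proj, up_proj]. load_up_proj_weight_first stands in for the quant-method
# attribute queried below; names here are illustrative only.
def w13_offset(shard_id: str, shard_size: int,
               load_up_proj_weight_first: bool) -> int:
    assert shard_id in ("w1", "w3")
    first = "w3" if load_up_proj_weight_first else "w1"
    return 0 if shard_id == first else shard_size

assert w13_offset("w1", 128, load_up_proj_weight_first=True) == 128   # gate goes second
assert w13_offset("w3", 128, load_up_proj_weight_first=False) == 128  # up goes second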
+ # trtllm cutlass kernel assumes differently + assert shard_id in ("w1", "w3") + switch_w13 = getattr(self.quant_method, 'load_up_proj_weight_first', + False) + if (switch_w13 and shard_id == "w1") or (not switch_w13 + and shard_id == "w3"): + start = shard_size else: - assert shard_id == "w3" - expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) + start = 0 + expert_data = expert_data.narrow(shard_dim, start, shard_size) expert_data.copy_(loaded_weight) def _load_w2(self, @@ -1390,7 +1421,10 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): final_hidden_states, non_blocking=True) ctx = get_forward_context() - max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu + #TODO(shuw):where is it? + # flashinfer_cutlass_kernels can handle TP+EP without DP + max_tokens_across_dp = (MOE_DP_CHUNK_SIZE if self.dp_size == 1 else + ctx.dp_metadata.max_tokens_across_dp_cpu) moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens num_tokens = full_hidden_states.size(0) @@ -1418,7 +1452,8 @@ def forward_impl(self, hidden_states: torch.Tensor, do_naive_dispatch_combine: bool = ( self.dp_size > 1 - and not self.moe_parallel_config.use_deepep_ht_kernels) + and not self.moe_parallel_config.use_deepep_ht_kernels + and not self.moe_parallel_config.use_flashinfer_cutlass_kernels) if do_naive_dispatch_combine: hidden_states, router_logits = get_ep_group().dispatch( hidden_states, router_logits) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 2ffb4d328eca..8b536caf89c6 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -410,6 +410,9 @@ def forward( a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, + extra_expert_args: Optional[dict] = None, + extra_prepare_args: Optional[dict] = None, + extra_finalize_args: Optional[dict] = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets @@ -442,6 +445,9 @@ def forward( - apply_router_weight_on_input (bool): When true, the topk weights are applied directly on the inputs. This is only applicable when topk is 1. + - extra_expert_args (Optional[dict]): Extra keyword arguments to pass to fused_experts.apply. + - extra_prepare_args (Optional[dict]): Extra keyword arguments to pass to prepare. + - extra_finalize_args (Optional[dict]): Extra keyword arguments to pass to finalize. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -454,18 +460,12 @@ def forward( if global_num_experts == -1: global_num_experts = local_num_experts + prepare_kwargs = extra_prepare_args or {} (a1q, a1q_scale, expert_num_tokens, _expert_topk_ids, _expert_topk_weights) = self.prepare_finalize.prepare( - a1, - a1_scale, - a2_scale, - topk_weights, - topk_ids, - global_num_experts, - expert_map, - apply_router_weight_on_input, - self.fused_experts.quant_config, - ) + a1, a1_scale, a2_scale, topk_weights, topk_ids, + global_num_experts, expert_map, apply_router_weight_on_input, + **prepare_kwargs) # Maybe prepare gathered topk_ids and topk_weights from other EP ranks. 
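# Standalone sketch of the optional-kwargs plumbing added to forward() here:
# each extra_*_args dict defaults to {} and is splatted into the corresponding
# prepare / apply / finalize call, so a backend can receive backend-specific
# keyword arguments without widening the shared signatures. Names below are
# illustrative only.
from typing import Optional

def apply_experts(output: str, **kwargs) -> dict:
    return {"output": output, **kwargs}

def forward_sketch(output: str, extra_expert_args: Optional[dict] = None) -> dict:
    expert_kwargs = extra_expert_args or {}
    return apply_experts(output, **expert_kwargs)

assert forward_sketch("out", {"topk_weights": None}) == {
    "output": "out", "topk_weights": None}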
topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids @@ -517,10 +517,13 @@ def forward( workspace2 = torch.empty(prod(workspace2_shape), device=a1.device, dtype=workspace_dtype) - + expert_kwargs = extra_expert_args or {} + import pdb + # pdb.set_trace() if num_chunks == 1: fused_out = _resize_cache(workspace13, fused_out_shape) - + if 'topk_weights' in expert_kwargs: + expert_kwargs['topk_weights'] = topk_weights self.fused_experts.apply( fused_out, a1q, @@ -539,6 +542,7 @@ def forward( workspace13=workspace13, workspace2=workspace2, expert_num_tokens=expert_num_tokens, + **expert_kwargs, ) else: # The leading output dimension may not be equal to M, so @@ -586,9 +590,11 @@ def forward( workspace13=workspace13, workspace2=workspace2, expert_num_tokens=expert_num_tokens, + **expert_kwargs, ) - + finalize_kwargs = extra_finalize_args or {} self.prepare_finalize.finalize(output, fused_out, topk_weights, - topk_ids, apply_router_weight_on_input) + topk_ids, apply_router_weight_on_input, + **finalize_kwargs) return output diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index 52346f797440..475c7f7bf4b6 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -4,6 +4,7 @@ from typing import Optional import torch +from flashinfer import fp4_quantize as fp4_quantize from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.fp8_utils import ( @@ -23,6 +24,16 @@ def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor: return x.flatten()[:prod(v)].view(*v) +def _fp4_quantize( + A: torch.Tensor, + A_scale: Optional[torch.Tensor], + is_sf_swizzled_layout: bool, +) -> tuple[torch.Tensor]: + return fp4_quantize(A, + A_scale, + is_sf_swizzled_layout=is_sf_swizzled_layout) + + def _fp8_quantize( A: torch.Tensor, A_scale: Optional[torch.Tensor], @@ -80,11 +91,14 @@ def moe_kernel_quantize_input( quant_dtype: Optional[torch.dtype], per_act_token_quant: bool, block_shape: Optional[list[int]] = None, + is_sf_swizzled_layout: bool = True, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: if quant_dtype == torch.float8_e4m3fn: return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape) elif quant_dtype == torch.int8: return _int8_quantize(A, A_scale, per_act_token_quant, block_shape) + elif quant_dtype == torch.uint8: # nvfp4 + return _fp4_quantize(A, A_scale, is_sf_swizzled_layout=is_sf_swizzled_layout) else: return A, A_scale diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index a10911b84afc..d9789515c642 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -1,15 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Callable, Optional, Union +import functools +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch from torch.nn import Module from torch.nn.parameter import Parameter +import vllm.envs as envs from vllm._custom_ops import (cutlass_scaled_fp4_mm, cutlass_scaled_mm_supports_fp4, scaled_fp4_quant) from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.model_executor.layers.linear import (LinearBase, 
LinearMethodBase, @@ -29,6 +32,24 @@ PerTensorScaleParameter) from vllm.platforms import current_platform from vllm.scalar_type import scalar_types +from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( + FlashInferExperts) +from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( + FlashInferCutlassMoEPrepareAndFinalize) +import vllm.model_executor.layers.fused_moe.modular_kernel as mk + +from vllm.distributed import ( + get_dp_group, get_ep_group, get_tensor_model_parallel_world_size) + +from vllm.model_executor.layers.fused_moe.prepare_finalize import ( + MoEPrepareAndFinalizeNoEP) +try: + from flashinfer import fp4_quantize as fp4_quantize + from flashinfer.fused_moe import ( + cutlass_fused_moe as flashinfer_cutlass_fused_moe) +except ImportError: + if not TYPE_CHECKING: + flashinfer_cutlass_fused_moe = None logger = init_logger(__name__) @@ -462,6 +483,21 @@ def __init__(self, quant_config: ModelOptNvFp4Config): self.quant_config = quant_config self.cutlass_nvfp4_supported = cutlass_fp4_supported() self.use_marlin = False + self.allow_flashinfer_cutlass = False + + if envs.VLLM_USE_FLASHINFER_MOE: + if not self.cutlass_nvfp4_supported: + logger.warning_once( + "Failed to import Flashinfer CUTLASS Fused MoE kernels.") + elif (current_platform.is_cuda() + and current_platform.has_device_capability(10, 0)): + logger.info_once( + "Using FlashInfer kernels for ModelOptNvFp4FusedMoE.") + self.allow_flashinfer_cutlass = True + else: + logger.warning_once( + "Flashinfer CUTLASS Fused MoE not supported on the current platform." + ) if not self.cutlass_nvfp4_supported: if is_fp4_marlin_supported(): @@ -470,6 +506,83 @@ def __init__(self, quant_config: ModelOptNvFp4Config): raise ValueError("Current platform does not support NVFP4" " quantization. Please use Blackwell and" " above.") + from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + cutlass_moe_fp4) + + self.fused_experts = cutlass_moe_fp4 + + @property + def load_up_proj_weight_first(self) -> bool: + # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13 + if self.allow_flashinfer_cutlass: + return True + return False + + def select_experts_impl(self, moe_parallel_config): + if not self.allow_flashinfer_cutlass: + return + + logger.debug("FlashInferExperts") + # default to TP/EP case only + + experts_kwargs = { + "use_nvfp4_w4a4": True, + "use_dp": moe_parallel_config.dp_size > 1, + } + if not moe_parallel_config.dp_size > 1 and moe_parallel_config.use_ep: + experts_kwargs["ep_rank"] = moe_parallel_config.ep_rank + experts_kwargs["ep_size"] = moe_parallel_config.ep_size + experts_kwargs["tp_rank"] = moe_parallel_config.tp_rank + experts_kwargs["tp_size"] = moe_parallel_config.tp_size + experts = FlashInferExperts(**experts_kwargs) + self.fused_experts = mk.FusedMoEModularKernel( + FlashInferCutlassMoEPrepareAndFinalize( + quant_dtype=torch.uint8, + #meaning 2x e2m1 packed in one, kernel requirement + ), + experts, + ) + + @property + def load_up_proj_weight_first(self) -> bool: + # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13 + if self.allow_flashinfer_cutlass: + return True + return False + + # This method update self.fused_experts + # only prepare_finalize is not None call select_gemm_impl + # so when native cutlass fp4, fused_expert is in fuse_moe.py fused_expert + # when it's not called(TP case), we still have 2 kernels to use. 
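# Schematic of the kernel selection that the comment above describes, using the
# names introduced in this patch (control flow is a sketch, not the actual
# implementation):
def pick_fused_experts(allow_flashinfer_cutlass: bool, dp_size: int) -> str:
    if allow_flashinfer_cutlass:
        # FlashInferExperts wrapped in a FusedMoEModularKernel; covers TP/EP and DP.
        return "FlashInferExperts"
    if dp_size > 1:
        # The plain CUTLASS FP4 path has no data-parallel support in this patch.
        raise ValueError("CutlassExpertsFp4 does not support DP; "
                         "use the FlashInfer CUTLASS backend instead")
    # Tensor-parallel-only fallback keeps the existing cutlass_moe_fp4 entry point.
    return "cutlass_moe_fp4"

assert pick_fused_experts(True, 2) == "FlashInferExperts"
assert pick_fused_experts(False, 1) == "cutlass_moe_fp4"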
+ def select_gemm_impl(self, prepare_finalize, moe): + + assert moe is not None + assert prepare_finalize is not None + experts = None + # print("fffff"*100) + all2all_manager = get_ep_group().device_communicator.all2all_manager + assert all2all_manager is not None + if self.allow_flashinfer_cutlass: + from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( + FlashInferExperts) + logger.debug("FlashInferExperts %s", moe) + # assert moe.dp_size == all2all_manager.dp_world_size + experts = FlashInferExperts( + use_nvfp4_w4a4=True, + use_dp=moe.moe_parallel_config.dp_size>1, + ep_rank=moe.moe_parallel_config.ep_rank, + ep_size=moe.moe_parallel_config.ep_size, + tp_rank=moe.moe_parallel_config.tp_rank, + tp_size=moe.moe_parallel_config.tp_size, + ) + else: + assert moe.dp_size > 1 + logger.debug("CutlassExpertsFp4 %s", moe) + # current doesn't support DP + raise ValueError("CutlassExpertsFp4 Doesn't support DP. " + "Use flashinfer CUTLASS FusedMoE backend instead.") + + return experts def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, @@ -707,9 +820,6 @@ def apply( assert not apply_router_weight_on_input, ( "Router weight on input is not " "supported for ModelOptNvFp4FusedMoE.") - assert expert_map is None, ("Expert Parallelism / expert_map " - "is currently not supported for " - "ModelOptNvFp4FusedMoE.") topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, @@ -722,25 +832,55 @@ def apply( custom_routing_function=custom_routing_function, scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias) + + if self.allow_flashinfer_cutlass: + # TP or DP case + # import pdb + # pdb.set_trace() + extra_expert_args = { + 'topk_weights': None, #placeholder topk_weights, + 'g1_alphas': layer.g1_alphas, + 'g2_alphas': layer.g2_alphas, + 'out_dtype': x.dtype, + 'a1_scale': layer.w13_input_scale_quant, + } + extra_prepare_args = { + 'use_dp': layer.dp_size > 1, + } + out = self.fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=False, # TODO(shuw): fix later, now output is high prec + activation=activation, + global_num_experts=global_num_experts, + w1_scale=layer.w13_blockscale_swizzled, + w2_scale=layer.w2_blockscale_swizzled, + a1_scale=layer.w13_input_scale_quant, + a2_scale=layer.w2_input_scale_quant, + extra_expert_args=extra_expert_args, + extra_prepare_args=extra_prepare_args, + ) + else: # cutlass_moe_fp4, TP case only + out = self.fused_experts( + a=x, + w1_fp4=layer.w13_weight, + w1_blockscale=layer.w13_blockscale_swizzled, + w1_alphas=layer.g1_alphas, + w2_fp4=layer.w2_weight, + w2_blockscale=layer.w2_blockscale_swizzled, + w2_alphas=layer.g2_alphas, + topk_weights=topk_weights, + topk_ids=topk_ids, + m=x.shape[0], + n=layer.w2_weight.shape[2] * 2, + k=x.shape[1], + e=layer.w13_weight.shape[0], + a1_gscale=layer.w13_input_scale_quant, + a2_gscale=layer.w2_input_scale_quant, + device=x.device + ).to(x.dtype) + return out - from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - cutlass_moe_fp4) - - # Cutlass moe takes in activations in BF16/Half precision - # and fp4 quantized weights loaded from the checkpoint - return cutlass_moe_fp4(a=x, - w1_fp4=layer.w13_weight, - w1_blockscale=layer.w13_blockscale_swizzled, - w1_alphas=layer.g1_alphas, - w2_fp4=layer.w2_weight, - w2_blockscale=layer.w2_blockscale_swizzled, - w2_alphas=layer.g2_alphas, - topk_weights=topk_weights, - 
topk_ids=topk_ids, - m=x.shape[0], - n=layer.w2_weight.shape[2] * 2, - k=x.shape[1], - e=layer.w13_weight.shape[0], - a1_gscale=layer.w13_input_scale_quant, - a2_gscale=layer.w2_input_scale_quant, - device=x.device).to(x.dtype) From 11f9136280bf4480611f7a4e1c82846b2b0c6e7f Mon Sep 17 00:00:00 2001 From: shuw Date: Fri, 4 Jul 2025 19:19:33 +0000 Subject: [PATCH 02/30] cutlass mooe path work --- .../layers/fused_moe/cutlass_moe.py | 246 +++++++++++++----- .../flashinfer_cutlass_prepare_finalize.py | 10 +- .../layers/fused_moe/modular_kernel.py | 21 +- .../layers/fused_moe/prepare_finalize.py | 21 +- .../layers/quantization/modelopt.py | 24 +- 5 files changed, 227 insertions(+), 95 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 213767867ff5..8138a37de128 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -398,36 +398,92 @@ def cutlass_moe_fp8( def run_cutlass_moe_fp4( output: torch.Tensor, - hidden_states: torch.Tensor, + a: torch.Tensor, + a1_gscale: torch.Tensor, w1_fp4: torch.Tensor, w1_blockscale: torch.Tensor, w1_alphas: torch.Tensor, + a2_gscale: torch.Tensor, w2_fp4: torch.Tensor, w2_blockscale: torch.Tensor, w2_alphas: torch.Tensor, - topk_ids: torch.Tensor, - a1_gscale: torch.Tensor, - a2_gscale: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, m: int, n: int, k: int, e: int, device: torch.device, ): - num_topk = topk_ids.shape[1] + """ + MoE implementation for FP4 Inputs + + # Gemm 1 + a: Input tensor: [m, k] (half/bfloat16) + a1_gscale: Activation scale per expert: [e] (float32) + w1(gate up) (not an argument to cutlass_moe_fp4): [e, 2 * n, k] + w1_fp4: [e, 2 * n, k // 2], dtype: torch.uint8 (stacked fp4: E2M1) + (Note: `n` is the up projection output dim, `k` is the input dim in + full precision) + w1_blockscale: [e, 2 * n, k // block_size] (float8_e4m3) + (Block size = 16 for NVFP4) + + # Gemm 2 + a2_gscale: Activation scale per expert: [e] + w2(down projection) (not an argument to cutlass_moe_fp4): [e, k, n] + w2_fp4: [e, k, n // 2], dtype: torch.uint8 (stacked E2M1) + w2_blockscale: [e, k, n // block_size], dtype: float8_e4m3 + + topk_weights: [m, topk] dtype: float8 + topk_ids: [m, topk] dtype: float8 + + m, n, k: Unquantized weight shapes, dtype: int + e: number of experts, dtype: int + + assumes that topk < k < n to satisfy - up/down projection expectations. 
+ """ + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert w1_fp4.dtype == torch.uint8, "weight 1 must be uint8" + assert w2_fp4.dtype == torch.uint8, "weight 2 must be uint8" + assert (w1_fp4.ndim == 3 and w2_fp4.ndim == 3 and w1_blockscale.ndim == 3 + and w2_blockscale.ndim + == 3), ("All Weights must be of rank 3 for cutlass_moe_fp4") + m_a, k_a = a.shape + e_w1, nx2_w1, half_k_w1 = w1_fp4.shape + e_w2, k_w2, half_n_w2 = w2_fp4.shape + + assert (e_w1 == e_w2 and e_w1 == e), ("Number of experts must match", + " between weights.") + assert (k_a // 2 == half_k_w1 + and k == k_w2), ("Hidden size mismatch between a, w1 and w2") + assert (nx2_w1 == n * 2 and half_n_w2 == n // 2), ("mismatch in " + "expected `n`") + assert (m == m_a), "input shape mismatch" + assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1" + assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype" + assert (topk_weights.size(0) == m and topk_ids.size(0) + == m), ("topk must be provided for each row of a") + + out_dtype = a.dtype + num_topk = topk_ids.size(1) + expert_offsets = torch.empty((e + 1), dtype=torch.int32, device=device) blockscale_offsets = torch.empty((e + 1), dtype=torch.int32, device=device) + # Problem size: (num_experts, (m,2n,k)) problem_sizes1 = torch.empty((e, 3), dtype=torch.int32, device=device) + # Problem size: (num_experts, (m,n,k)) problem_sizes2 = torch.empty((e, 3), dtype=torch.int32, device=device) + a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) + # problem shapes should have [m, n, k] + # Note that problem sizes are based on logical number of elements. ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1, problem_sizes2, a_map, c_map, e, n, k, blockscale_offsets) - a = ops.shuffle_rows(hidden_states, a_map) + a = ops.shuffle_rows(a, a_map) rep_a_fp4, rep_a_blockscale = ops.scaled_fp4_experts_quant( a, @@ -440,12 +496,13 @@ def run_cutlass_moe_fp4( c1 = ops.cutlass_fp4_moe_mm(rep_a_fp4, w1_fp4, rep_a_blockscale, w1_blockscale, w1_alphas, problem_sizes1, expert_offsets[:-1], blockscale_offsets[:-1], - output.dtype, device) + out_dtype, device) del rep_a_fp4, rep_a_blockscale # hidden size dimension is split to one halfpytho sized tensor. 
intermediate = torch.empty((m * num_topk, w1_fp4.size(1) // 2), device=device, - dtype=output.dtype) + dtype=out_dtype) + torch.ops._C.silu_and_mul(intermediate, c1) int_fp4, int_blockscale = ops.scaled_fp4_experts_quant( @@ -453,12 +510,14 @@ def run_cutlass_moe_fp4( c2 = ops.cutlass_fp4_moe_mm(int_fp4, w2_fp4, int_blockscale, w2_blockscale, w2_alphas, problem_sizes2, expert_offsets[:-1], - blockscale_offsets[:-1], output.dtype, device) + blockscale_offsets[:-1], out_dtype, device) del int_fp4, int_blockscale c2 = ops.shuffle_rows(c2, c_map) - out = (c2.view(m, num_topk, k) * topk_weights.view(m, num_topk, 1).half()).sum(dim=1) - output.copy_(out.to(dtype=output.dtype), non_blocking=True) + assert output.dtype == out_dtype + output.copy_((c2.view(m, num_topk, k) * + topk_weights.view(m, num_topk, 1).half()).sum(dim=1)) + return class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): @@ -466,13 +525,39 @@ def __init__( self, max_experts_per_worker: int, out_dtype: torch.dtype, + per_act_token_quant: bool, + per_out_ch_quant: bool, + block_shape: Optional[list[int]] = None, + use_batched_format: bool = False, ): - super().__init__() + super().__init__( + FusedMoEQuantConfig( + quant_dtype=torch.uint8, + per_act_token_quant=per_act_token_quant, + per_out_ch_quant=per_out_ch_quant, + block_shape=block_shape, + ) + ) self.max_experts_per_worker = max_experts_per_worker self.out_dtype = out_dtype + self.use_batched_format = use_batched_format + + @property + def activation_formats( + self + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + if self.use_batched_format: + return (mk.FusedMoEActivationFormat.BatchedExperts, + mk.FusedMoEActivationFormat.BatchedExperts) + else: + return (mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard) + + def supports_expert_map(self) -> bool: + return False def supports_chunking(self) -> bool: - return True + return False def workspace_shapes( self, @@ -486,58 +571,57 @@ def workspace_shapes( local_num_experts: int, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: # Workspace1: for c1, Workspace2: for intermediate, Output: for final output - workspace1 = (M * topk, max(2 * N, K)) - workspace2 = (M * topk, N) - output = (M, K) + workspace1 = a.shape + workspace2 = () + output = a.shape return (workspace1, workspace2, output, self.out_dtype) def apply( self, output: torch.Tensor, hidden_states: torch.Tensor, - w1_fp4: torch.Tensor, - w2_fp4: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, topk_ids: torch.Tensor, activation: str, global_num_experts: int, expert_map: Optional[torch.Tensor], - w1_blockscale: Optional[torch.Tensor], - w2_blockscale: Optional[torch.Tensor], - w1_alphas: Optional[torch.Tensor], - w2_alphas: Optional[torch.Tensor], - a1_gscale: Optional[torch.Tensor], - a2_gscale: Optional[torch.Tensor], + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: torch.Tensor, + workspace13: Optional[torch.Tensor], + workspace2: Optional[torch.Tensor], expert_num_tokens: Optional[torch.Tensor], - topk_weights: Optional[torch.Tensor] = None, - m: Optional[int] = None, - n: Optional[int] = None, - k: Optional[int] = None, - e: Optional[int] = None, - device: Optional[torch.device] = None, + # start of extra_expert_args + topk_weights: torch.Tensor, + g1_alphas: torch.Tensor, + g2_alphas: torch.Tensor, + a1_scale: torch.Tensor, + m: int, + n: int, + k: int, + e: int, 
+ device: torch.device, ): - # All required args must be provided - assert w1_blockscale is not None and w2_blockscale is not None - assert w1_alphas is not None and w2_alphas is not None - assert a1_gscale is not None and a2_gscale is not None - assert topk_weights is not None - assert m is not None and n is not None and k is not None and e is not None - assert device is not None assert expert_map is None, ("Expert Parallelism / expert_map " "is currently not supported for " "ModelOptNvFp4FusedMoE.") run_cutlass_moe_fp4( output, hidden_states, - w1_fp4, - w1_blockscale, - w1_alphas, - w2_fp4, - w2_blockscale, - w2_alphas, - topk_ids, - a1_gscale, - a2_gscale, + a1_scale, + w1, + w1_scale, + g1_alphas, + a2_scale, + w2, + w2_scale, + g2_alphas, topk_weights, + topk_ids, m, n, k, @@ -554,8 +638,8 @@ def cutlass_moe_fp4( w2_blockscale: torch.Tensor, g1_alphas: torch.Tensor, g2_alphas: torch.Tensor, - a1_gscale: torch.Tensor, - a2_gscale: torch.Tensor, + a1_scale: torch.Tensor, + a2_scale: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, m: int, @@ -563,36 +647,54 @@ def cutlass_moe_fp4( k: int, e: int, device: torch.device, + expert_map: Optional[torch.Tensor] = None, ) -> torch.Tensor: + assert expert_map is None, ("Expert Parallelism / expert_map " + "is currently not supported for " + "ModelOptNvFp4FusedMoE's cutlass_moe_fp4.") fn = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP( - quant_dtype=torch.uint8, # FP4 packed in uint8 - per_channel_quant=False, - ), + MoEPrepareAndFinalizeNoEP(), CutlassExpertsFp4( max_experts_per_worker=e, out_dtype=a.dtype, + per_act_token_quant=False, + per_out_ch_quant=False, + use_batched_format=False, ), ) + extra_expert_args = { + 'topk_weights': topk_weights, + 'g1_alphas': g1_alphas, + 'g2_alphas': g2_alphas, + 'a1_scale': a1_scale, + 'm': m, + 'n': n, + 'k': k, + 'e': e, + 'device': device, + } + extra_prepare_args = { + 'skip_quant': True, + } + extra_finalize_args = { + 'skip_permute_reduce': True, + } return fn( - a, - w1_fp4, - w2_fp4, - topk_weights, - topk_ids, - False, - "silu", - e, - None, - w1_blockscale, - w2_blockscale, - g1_alphas, - g2_alphas, - a1_gscale=a1_gscale, - a2_gscale=a2_gscale, - m=m, - n=n, - k=k, - e=e, - device=device, + hidden_states=a, + w1=w1_fp4, + w2=w2_fp4, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=False, + activation="silu", + global_num_experts=e, + expert_map=None, + w1_scale=w1_blockscale, + w2_scale=w2_blockscale, + a1_scale=a1_scale, + a2_scale=a2_scale, + apply_router_weight_on_input=False, + extra_expert_args=extra_expert_args, + extra_prepare_args=extra_prepare_args, + extra_finalize_args=extra_finalize_args, ) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py index a023a181f116..ab12ca8ec121 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -4,6 +4,7 @@ import torch from flashinfer import fp4_swizzle_blockscale +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.distributed import get_dp_group @@ -20,7 +21,7 @@ def get_local_sizes(): sizes.append((cu_sizes[i] - cu_sizes[i - 1]).item()) return sizes - +#should be mk.FusedMoEPrepareAndFinalize chk pplx class FlashInferCutlassMoEPrepareAndFinalize(MoEPrepareAndFinalizeNoEP ): @@ -50,8 
+51,9 @@ def prepare( topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], - apply_router_weight_on_input: bool = False, - use_dp: bool = True, + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + use_dp: Optional[bool] = True, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: @@ -65,7 +67,7 @@ def prepare( a1q, a1q_scale = moe_kernel_quantize_input( a1, a1_scale, - self.quant_dtype, + quant_config.quant_dtype, self.per_channel_quant, self.block_shape, is_sf_swizzled_layout= diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 8b536caf89c6..4437e72c3b9f 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -461,11 +461,21 @@ def forward( global_num_experts = local_num_experts prepare_kwargs = extra_prepare_args or {} + # import pdb + # pdb.set_trace() (a1q, a1q_scale, expert_num_tokens, _expert_topk_ids, _expert_topk_weights) = self.prepare_finalize.prepare( - a1, a1_scale, a2_scale, topk_weights, topk_ids, - global_num_experts, expert_map, apply_router_weight_on_input, - **prepare_kwargs) + a1, + a1_scale, + a2_scale, + topk_weights, + topk_ids, + global_num_experts, + expert_map, + apply_router_weight_on_input, + self.fused_experts.quant_config, + **prepare_kwargs, + ) # Maybe prepare gathered topk_ids and topk_weights from other EP ranks. topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids @@ -518,11 +528,10 @@ def forward( device=a1.device, dtype=workspace_dtype) expert_kwargs = extra_expert_args or {} - import pdb - # pdb.set_trace() + if num_chunks == 1: fused_out = _resize_cache(workspace13, fused_out_shape) - if 'topk_weights' in expert_kwargs: + if 'topk_weights' in expert_kwargs and expert_kwargs['topk_weights'] is None: expert_kwargs['topk_weights'] = topk_weights self.fused_experts.apply( fused_out, diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index 9e4be82f6c1f..ec67e6984df5 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -35,6 +35,7 @@ def prepare( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, + skip_quant: Optional[bool]=False, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: @@ -44,13 +45,18 @@ def prepare( assert topk == 1, \ "apply_router_weight_on_input is only implemented for topk=1" a1.mul_(topk_weights.to(a1.dtype)) - + + if skip_quant: + # print("skiped_quant"*10) + # print(f"skiped_quant:{skip_quant}") + return a1, None, None, None, None a1q, a1q_scale = moe_kernel_quantize_input( a1, a1_scale, quant_config.quant_dtype, quant_config.per_act_token_quant, quant_config.block_shape) return a1q, a1q_scale, None, None, None + def finalize( self, output: torch.Tensor, @@ -58,6 +64,15 @@ def finalize( topk_weights: torch.Tensor, topk_ids: torch.Tensor, apply_router_weight_on_input: bool, + skip_permute_reduce: Optional[bool]=False, ) -> None: - _moe_unpermute_and_reduce(output, fused_expert_output, None, - topk_weights, apply_router_weight_on_input) + # print("skip_permute_reduce"*10) + # print(f"skip_permute_reduce:{skip_permute_reduce}") + if skip_permute_reduce: + assert 
output.shape == fused_expert_output.shape + output.copy_(fused_expert_output) + else: + _moe_unpermute_and_reduce( + output, fused_expert_output, None, + topk_weights, apply_router_weight_on_input + ) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index d9789515c642..5bfc761a2066 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -520,7 +520,7 @@ def load_up_proj_weight_first(self) -> bool: def select_experts_impl(self, moe_parallel_config): if not self.allow_flashinfer_cutlass: - return + return logger.debug("FlashInferExperts") # default to TP/EP case only @@ -863,24 +863,28 @@ def apply( extra_expert_args=extra_expert_args, extra_prepare_args=extra_prepare_args, ) - else: # cutlass_moe_fp4, TP case only + else: + from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + run_cutlass_moe_fp4) + + # cutlass_moe_fp4, TP case only(no EP) out = self.fused_experts( a=x, w1_fp4=layer.w13_weight, - w1_blockscale=layer.w13_blockscale_swizzled, - w1_alphas=layer.g1_alphas, w2_fp4=layer.w2_weight, + w1_blockscale=layer.w13_blockscale_swizzled, w2_blockscale=layer.w2_blockscale_swizzled, - w2_alphas=layer.g2_alphas, + g1_alphas=layer.g1_alphas, + g2_alphas=layer.g2_alphas, + a1_scale=layer.w13_input_scale_quant, + a2_scale=layer.w2_input_scale_quant, topk_weights=topk_weights, topk_ids=topk_ids, m=x.shape[0], n=layer.w2_weight.shape[2] * 2, k=x.shape[1], e=layer.w13_weight.shape[0], - a1_gscale=layer.w13_input_scale_quant, - a2_gscale=layer.w2_input_scale_quant, - device=x.device - ).to(x.dtype) + device=x.device, + expert_map=expert_map, + ) return out - From 92994d5377c0ebbe9eefeb346e1085fc7aa44e0e Mon Sep 17 00:00:00 2001 From: shuw Date: Fri, 4 Jul 2025 22:19:15 +0000 Subject: [PATCH 03/30] flashinfer tp pass --- .../fused_moe/flashinfer_cutlass_moe.py | 14 ++++--- .../flashinfer_cutlass_prepare_finalize.py | 2 +- vllm/model_executor/layers/fused_moe/layer.py | 4 +- .../layers/fused_moe/modular_kernel.py | 2 + vllm/model_executor/layers/fused_moe/utils.py | 1 + .../layers/quantization/modelopt.py | 38 ++++++++++++++++--- 6 files changed, 48 insertions(+), 13 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index f380e2847b75..8ebce10f77e4 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -18,12 +18,12 @@ try: from flashinfer import fp4_quantize as fp4_quantize - from flashinfer.fused_moe import cutlass_fused_moe as cutlass_fused_moe + from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe except ImportError: if not TYPE_CHECKING: cutlass_fused_moe = None -has_flashinfer_cutlass_fused_moe = cutlass_fused_moe is not None +has_flashinfer_cutlass_fused_moe = flashinfer_cutlass_fused_moe is not None #TODO(shuw): use this check def _valid_flashinfer_fused_moe(hidden_states: torch.Tensor, w1: torch.Tensor, @@ -157,8 +157,9 @@ def apply( torch.min(a2_scale), w2_scale.view(torch.int32), g2_alphas, - ] - output = cutlass_fused_moe( + ] + # print(self.ep_size, self.ep_rank, self.tp_rank, self.tp_size) + out = flashinfer_cutlass_fused_moe( hidden_states, topk_ids.to(torch.int), topk_weights, @@ -168,10 +169,11 @@ def apply( output_dtype=out_dtype, quant_scales=quant_scales, input_sf=a1q_scale, - ep_size=self.ep_size, - 
ep_rank=self.ep_rank, tp_size=self.tp_size, tp_rank=self.tp_rank, + ep_size=self.ep_size, + ep_rank=self.ep_rank, )[0] + output.copy_(out) else: raise ValueError("Only nvfp4 quantization is currently supported.") diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py index ab12ca8ec121..58ada401b028 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -66,7 +66,7 @@ def prepare( a1q, a1q_scale = moe_kernel_quantize_input( a1, - a1_scale, + torch.min(a1_scale), # special to nvfp4 quant_config.quant_dtype, self.per_channel_quant, self.block_shape, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index ecfceb9e57e3..9af5dff999b4 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1454,6 +1454,7 @@ def forward_impl(self, hidden_states: torch.Tensor, self.dp_size > 1 and not self.moe_parallel_config.use_deepep_ht_kernels and not self.moe_parallel_config.use_flashinfer_cutlass_kernels) + # do_naive_dispatch_combine = True if do_naive_dispatch_combine: hidden_states, router_logits = get_ep_group().dispatch( hidden_states, router_logits) @@ -1483,8 +1484,9 @@ def forward_impl(self, hidden_states: torch.Tensor, if do_naive_dispatch_combine: final_hidden_states = get_ep_group().combine(final_hidden_states) - + # print(f"reduce_results:{self.reduce_results}") if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): + # if True and (self.tp_size > 1 or self.ep_size > 1): # Default set to False. (May have to add shared expert outputs. 
final_hidden_states = self.maybe_all_reduce_tensor_model_parallel( final_hidden_states) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 4437e72c3b9f..68172f8f8b5f 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -533,6 +533,8 @@ def forward( fused_out = _resize_cache(workspace13, fused_out_shape) if 'topk_weights' in expert_kwargs and expert_kwargs['topk_weights'] is None: expert_kwargs['topk_weights'] = topk_weights + assert expert_kwargs['topk_weights'] is not None + # print("replacing topk_weights") self.fused_experts.apply( fused_out, a1q, diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index 475c7f7bf4b6..4f18e6e3e416 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -98,6 +98,7 @@ def moe_kernel_quantize_input( elif quant_dtype == torch.int8: return _int8_quantize(A, A_scale, per_act_token_quant, block_shape) elif quant_dtype == torch.uint8: # nvfp4 + # print(f"calling fp4 quantize:{is_sf_swizzled_layout}") return _fp4_quantize(A, A_scale, is_sf_swizzled_layout=is_sf_swizzled_layout) else: return A, A_scale diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 5bfc761a2066..752aa068b0d7 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -529,11 +529,11 @@ def select_experts_impl(self, moe_parallel_config): "use_nvfp4_w4a4": True, "use_dp": moe_parallel_config.dp_size > 1, } - if not moe_parallel_config.dp_size > 1 and moe_parallel_config.use_ep: - experts_kwargs["ep_rank"] = moe_parallel_config.ep_rank - experts_kwargs["ep_size"] = moe_parallel_config.ep_size - experts_kwargs["tp_rank"] = moe_parallel_config.tp_rank - experts_kwargs["tp_size"] = moe_parallel_config.tp_size + # if not moe_parallel_config.dp_size > 1 and moe_parallel_config.use_ep: + experts_kwargs["ep_rank"] = moe_parallel_config.ep_rank + experts_kwargs["ep_size"] = moe_parallel_config.ep_size + experts_kwargs["tp_rank"] = moe_parallel_config.tp_rank + experts_kwargs["tp_size"] = moe_parallel_config.tp_size experts = FlashInferExperts(**experts_kwargs) self.fused_experts = mk.FusedMoEModularKernel( FlashInferCutlassMoEPrepareAndFinalize( @@ -847,6 +847,33 @@ def apply( extra_prepare_args = { 'use_dp': layer.dp_size > 1, } + # from flashinfer import fp4_quantize as fp4_quantize + # from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe + # quant_scales = [ + # torch.min(layer.w13_input_scale_quant), + # layer.w13_blockscale_swizzled.view(torch.int32), + # layer.g1_alphas, + # torch.min(layer.w2_input_scale_quant), + # layer.w2_blockscale_swizzled.view(torch.int32), + # layer.g2_alphas, + # ] + # xq, input_sf = fp4_quantize(x, torch.min(layer.w13_input_scale_quant)) + # out = flashinfer_cutlass_fused_moe( + # xq, + # topk_ids, + # topk_weights, + # layer.w13_weight.view(torch.long), + # layer.w2_weight.view(torch.long), + # x.dtype, + # quant_scales, + # input_sf, + # self.fused_experts.fused_experts.tp_size, + # self.fused_experts.fused_experts.tp_rank, + # self.fused_experts.fused_experts.ep_size, + # self.fused_experts.fused_experts.ep_rank, + # )[0] + + # print(f"usedp:{layer.dp_size > 1}") out = self.fused_experts( hidden_states=x, w1=layer.w13_weight, @@ 
-856,6 +883,7 @@ def apply( inplace=False, # TODO(shuw): fix later, now output is high prec activation=activation, global_num_experts=global_num_experts, + expert_map=expert_map, w1_scale=layer.w13_blockscale_swizzled, w2_scale=layer.w2_blockscale_swizzled, a1_scale=layer.w13_input_scale_quant, From 93a3a0a6c5fd077cec221ac290baa75a3307cfcd Mon Sep 17 00:00:00 2001 From: shuw Date: Sat, 5 Jul 2025 05:55:24 +0000 Subject: [PATCH 04/30] dp work --- vllm/distributed/parallel_state.py | 7 +++--- .../fused_moe/flashinfer_cutlass_moe.py | 14 +++++++---- .../flashinfer_cutlass_prepare_finalize.py | 7 ++++-- vllm/model_executor/layers/fused_moe/layer.py | 2 -- .../layers/fused_moe/modular_kernel.py | 2 ++ vllm/model_executor/layers/fused_moe/utils.py | 4 ++-- .../layers/quantization/modelopt.py | 23 +++++++++---------- 7 files changed, 34 insertions(+), 25 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index f09e28c4edc2..964d13b61699 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -368,8 +368,9 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: return input_ assert -input_.dim() <= dim < input_.dim(), ( f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") - - if self.use_custom_op_call: + + # TODO(shuw): enable it + if self.use_custom_op_call and False: return torch.ops.vllm.all_gather(input_, dim, world_size, @@ -398,7 +399,7 @@ def reduce_scatter(self, assert -input_.dim() <= dim < input_.dim(), ( f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") - if self.use_custom_op_call: + if self.use_custom_op_call and False: assert sizes is None, "Varying size reduce scatter not supported with vllm custom op" return torch.ops.vllm.reduce_scatter(input_, dim, diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 8ebce10f77e4..01c608e5c686 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -115,11 +115,14 @@ def workspace_shapes( # M_sum = (M * topk) + num_experts * (block_m - 1) # M_sum = round_up(M_sum, block_m) # workspace1 = () + # TODO(shuw): This is nvfp4 specialized, add branch for other quant type. + aq_m, aq_n = aq.shape workspace2 = () - output_shape = a.shape + output_shape = (aq_m, aq_n * 2) workspace_dtype = a.dtype workspace1 = output_shape - + # print(f"inside workspace_shape: workspace1:{workspace1} and output_shape:{output_shape} with type:{workspace_dtype}") + # determined by aq, since aq is the one after possible communication op and participate in experts computation. return (workspace1, workspace2, output_shape, workspace_dtype) def apply( @@ -151,10 +154,10 @@ def apply( # min because inv_scale. 
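        # For the NVFP4 path, quant_scales is the flat list the FlashInfer
        # cutlass_fused_moe kernel consumes:
        # [a1 global scale, w1 block scales (viewed as int32), g1 alphas,
        #  a2 global scale, w2 block scales (viewed as int32), g2 alphas].
        # The activation scales are stored as inverse scales, hence a single
        # scalar is taken with torch.min (see the note above).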
if self.use_nvfp4_w4a4: quant_scales = [ - torch.min(a1_scale), + a1_scale, w1_scale.view(torch.int32), g1_alphas, - torch.min(a2_scale), + a2_scale, w2_scale.view(torch.int32), g2_alphas, ] @@ -174,6 +177,9 @@ def apply( ep_size=self.ep_size, ep_rank=self.ep_rank, )[0] + # print(f"callsite hidden_states:{hidden_states.shape}") + # print(f"tmp:{out.shape}") + # print(f"output:{output.shape}") output.copy_(out) else: raise ValueError("Only nvfp4 quantization is currently supported.") diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py index 58ada401b028..e9f0b0c4d5dd 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -66,11 +66,11 @@ def prepare( a1q, a1q_scale = moe_kernel_quantize_input( a1, - torch.min(a1_scale), # special to nvfp4 + a1_scale, quant_config.quant_dtype, self.per_channel_quant, self.block_shape, - is_sf_swizzled_layout= + is_fp4_scalar_swizzled= not use_dp, # Needs swizzling after communication ) if use_dp: @@ -98,4 +98,7 @@ def finalize( dim=0, sizes=get_local_sizes(), ) + # print(f"use_dp: {use_dp}") + # print(f"output:{output.shape}") + # print(f"fused_expert_output:{fused_expert_output.shape}") output.copy_(fused_expert_output) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 9af5dff999b4..ea16807e1b08 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1484,9 +1484,7 @@ def forward_impl(self, hidden_states: torch.Tensor, if do_naive_dispatch_combine: final_hidden_states = get_ep_group().combine(final_hidden_states) - # print(f"reduce_results:{self.reduce_results}") if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): - # if True and (self.tp_size > 1 or self.ep_size > 1): # Default set to False. (May have to add shared expert outputs. 
final_hidden_states = self.maybe_all_reduce_tensor_model_parallel( final_hidden_states) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 68172f8f8b5f..2d4f012090d2 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -531,6 +531,8 @@ def forward( if num_chunks == 1: fused_out = _resize_cache(workspace13, fused_out_shape) + # print(f"fused_out_shape:{fused_out_shape}") + # print(f"a1q:{a1q.shape}") if 'topk_weights' in expert_kwargs and expert_kwargs['topk_weights'] is None: expert_kwargs['topk_weights'] = topk_weights assert expert_kwargs['topk_weights'] is not None diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index 4f18e6e3e416..f1d6c6bfd047 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -91,7 +91,7 @@ def moe_kernel_quantize_input( quant_dtype: Optional[torch.dtype], per_act_token_quant: bool, block_shape: Optional[list[int]] = None, - is_sf_swizzled_layout: bool = True, + is_fp4_scalar_swizzled: bool = True, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: if quant_dtype == torch.float8_e4m3fn: return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape) @@ -99,7 +99,7 @@ def moe_kernel_quantize_input( return _int8_quantize(A, A_scale, per_act_token_quant, block_shape) elif quant_dtype == torch.uint8: # nvfp4 # print(f"calling fp4 quantize:{is_sf_swizzled_layout}") - return _fp4_quantize(A, A_scale, is_sf_swizzled_layout=is_sf_swizzled_layout) + return _fp4_quantize(A, A_scale, is_sf_swizzled_layout=is_fp4_scalar_swizzled) else: return A, A_scale diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 752aa068b0d7..eeeee43836c3 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -486,17 +486,14 @@ def __init__(self, quant_config: ModelOptNvFp4Config): self.allow_flashinfer_cutlass = False if envs.VLLM_USE_FLASHINFER_MOE: - if not self.cutlass_nvfp4_supported: - logger.warning_once( - "Failed to import Flashinfer CUTLASS Fused MoE kernels.") - elif (current_platform.is_cuda() - and current_platform.has_device_capability(10, 0)): + if self.cutlass_nvfp4_supported and current_platform.is_cuda() \ + and current_platform.has_device_capability(10, 0): logger.info_once( "Using FlashInfer kernels for ModelOptNvFp4FusedMoE.") self.allow_flashinfer_cutlass = True else: logger.warning_once( - "Flashinfer CUTLASS Fused MoE not supported on the current platform." + "Flashinfer CUTLASS Fused MoE not supported or found on the current platform." 
) if not self.cutlass_nvfp4_supported: @@ -514,9 +511,7 @@ def __init__(self, quant_config: ModelOptNvFp4Config): @property def load_up_proj_weight_first(self) -> bool: # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13 - if self.allow_flashinfer_cutlass: - return True - return False + return self.allow_flashinfer_cutlass def select_experts_impl(self, moe_parallel_config): if not self.allow_flashinfer_cutlass: @@ -842,11 +837,14 @@ def apply( 'g1_alphas': layer.g1_alphas, 'g2_alphas': layer.g2_alphas, 'out_dtype': x.dtype, - 'a1_scale': layer.w13_input_scale_quant, + 'a1_scale': torch.min(layer.w13_input_scale_quant), } extra_prepare_args = { 'use_dp': layer.dp_size > 1, } + extra_finalize_args = { + 'use_dp': layer.dp_size > 1, + } # from flashinfer import fp4_quantize as fp4_quantize # from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe # quant_scales = [ @@ -886,10 +884,11 @@ def apply( expert_map=expert_map, w1_scale=layer.w13_blockscale_swizzled, w2_scale=layer.w2_blockscale_swizzled, - a1_scale=layer.w13_input_scale_quant, - a2_scale=layer.w2_input_scale_quant, + a1_scale=torch.min(layer.w13_input_scale_quant), + a2_scale=torch.min(layer.w2_input_scale_quant), extra_expert_args=extra_expert_args, extra_prepare_args=extra_prepare_args, + extra_finalize_args=extra_finalize_args, ) else: from vllm.model_executor.layers.fused_moe.cutlass_moe import ( From c1a74ebb4c2b466e8cbf2a23a28d2e4d8e95618a Mon Sep 17 00:00:00 2001 From: shuw Date: Sat, 5 Jul 2025 05:58:24 +0000 Subject: [PATCH 05/30] Clean up --- .../fused_moe/flashinfer_cutlass_moe.py | 6 +--- .../flashinfer_cutlass_prepare_finalize.py | 3 -- .../layers/fused_moe/modular_kernel.py | 3 -- vllm/model_executor/layers/fused_moe/utils.py | 1 - .../layers/quantization/modelopt.py | 29 +------------------ 5 files changed, 2 insertions(+), 40 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 01c608e5c686..688849c1406d 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -121,7 +121,6 @@ def workspace_shapes( output_shape = (aq_m, aq_n * 2) workspace_dtype = a.dtype workspace1 = output_shape - # print(f"inside workspace_shape: workspace1:{workspace1} and output_shape:{output_shape} with type:{workspace_dtype}") # determined by aq, since aq is the one after possible communication op and participate in experts computation. 
return (workspace1, workspace2, output_shape, workspace_dtype) @@ -161,7 +160,6 @@ def apply( w2_scale.view(torch.int32), g2_alphas, ] - # print(self.ep_size, self.ep_rank, self.tp_rank, self.tp_size) out = flashinfer_cutlass_fused_moe( hidden_states, topk_ids.to(torch.int), @@ -177,9 +175,7 @@ def apply( ep_size=self.ep_size, ep_rank=self.ep_rank, )[0] - # print(f"callsite hidden_states:{hidden_states.shape}") - # print(f"tmp:{out.shape}") - # print(f"output:{output.shape}") + # TODO(shuw): Handle the allocation from FlashInfer to framework output.copy_(out) else: raise ValueError("Only nvfp4 quantization is currently supported.") diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py index e9f0b0c4d5dd..b6d13d03c275 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -98,7 +98,4 @@ def finalize( dim=0, sizes=get_local_sizes(), ) - # print(f"use_dp: {use_dp}") - # print(f"output:{output.shape}") - # print(f"fused_expert_output:{fused_expert_output.shape}") output.copy_(fused_expert_output) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 2d4f012090d2..294d87840880 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -531,12 +531,9 @@ def forward( if num_chunks == 1: fused_out = _resize_cache(workspace13, fused_out_shape) - # print(f"fused_out_shape:{fused_out_shape}") - # print(f"a1q:{a1q.shape}") if 'topk_weights' in expert_kwargs and expert_kwargs['topk_weights'] is None: expert_kwargs['topk_weights'] = topk_weights assert expert_kwargs['topk_weights'] is not None - # print("replacing topk_weights") self.fused_experts.apply( fused_out, a1q, diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index f1d6c6bfd047..ec055a76c18c 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -98,7 +98,6 @@ def moe_kernel_quantize_input( elif quant_dtype == torch.int8: return _int8_quantize(A, A_scale, per_act_token_quant, block_shape) elif quant_dtype == torch.uint8: # nvfp4 - # print(f"calling fp4 quantize:{is_sf_swizzled_layout}") return _fp4_quantize(A, A_scale, is_sf_swizzled_layout=is_fp4_scalar_swizzled) else: return A, A_scale diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index eeeee43836c3..b6f54126ff3e 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -554,7 +554,6 @@ def select_gemm_impl(self, prepare_finalize, moe): assert moe is not None assert prepare_finalize is not None experts = None - # print("fffff"*100) all2all_manager = get_ep_group().device_communicator.all2all_manager assert all2all_manager is not None if self.allow_flashinfer_cutlass: @@ -845,33 +844,7 @@ def apply( extra_finalize_args = { 'use_dp': layer.dp_size > 1, } - # from flashinfer import fp4_quantize as fp4_quantize - # from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe - # quant_scales = [ - # torch.min(layer.w13_input_scale_quant), - # layer.w13_blockscale_swizzled.view(torch.int32), - # layer.g1_alphas, - # 
torch.min(layer.w2_input_scale_quant), - # layer.w2_blockscale_swizzled.view(torch.int32), - # layer.g2_alphas, - # ] - # xq, input_sf = fp4_quantize(x, torch.min(layer.w13_input_scale_quant)) - # out = flashinfer_cutlass_fused_moe( - # xq, - # topk_ids, - # topk_weights, - # layer.w13_weight.view(torch.long), - # layer.w2_weight.view(torch.long), - # x.dtype, - # quant_scales, - # input_sf, - # self.fused_experts.fused_experts.tp_size, - # self.fused_experts.fused_experts.tp_rank, - # self.fused_experts.fused_experts.ep_size, - # self.fused_experts.fused_experts.ep_rank, - # )[0] - - # print(f"usedp:{layer.dp_size > 1}") + out = self.fused_experts( hidden_states=x, w1=layer.w13_weight, From f4dc86bf4292d662a00e7061d95b912b05ba122d Mon Sep 17 00:00:00 2001 From: shuw Date: Mon, 7 Jul 2025 19:41:13 +0000 Subject: [PATCH 06/30] Fix NoEP class --- .../fused_moe/flashinfer_cutlass_moe.py | 32 ++++++++----------- .../flashinfer_cutlass_prepare_finalize.py | 8 +++-- 2 files changed, 19 insertions(+), 21 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 688849c1406d..b8627b3b3f5c 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -109,20 +109,17 @@ def workspace_shapes( - Workspace type: The dtype to use for the workspace tensors. - Note: in order for activation chunking to work, the first dimension of each tuple must be the number of tokens. - """ - # num_experts = global_num_experts - # block_m = self.block_shape[0] - # M_sum = (M * topk) + num_experts * (block_m - 1) - # M_sum = round_up(M_sum, block_m) - # workspace1 = () - # TODO(shuw): This is nvfp4 specialized, add branch for other quant type. - aq_m, aq_n = aq.shape - workspace2 = () - output_shape = (aq_m, aq_n * 2) - workspace_dtype = a.dtype - workspace1 = output_shape - # determined by aq, since aq is the one after possible communication op and participate in experts computation. - return (workspace1, workspace2, output_shape, workspace_dtype) + """ + if self.use_nvfp4_w4a4: + aq_m, aq_n = aq.shape + workspace2 = () + output_shape = (aq_m, aq_n * 2) + workspace_dtype = a.dtype + workspace1 = output_shape + # determined by aq, since aq is the one after possible communication op and participate in experts computation. 
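+            # Illustrative example (assumed sizes): aq packs two FP4 values per
+            # uint8, so aq of shape [256, 1536] implies a hidden size of 3072
+            # and gives workspace1 == output_shape == (256, 3072) in a.dtype;
+            # the kernel writes its unquantized result directly into that
+            # buffer.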
+ return (workspace1, workspace2, output_shape, workspace_dtype) + else: + raise ValueError("Only nvfp4 quantization is currently supported.") def apply( self, @@ -160,7 +157,7 @@ def apply( w2_scale.view(torch.int32), g2_alphas, ] - out = flashinfer_cutlass_fused_moe( + _ = flashinfer_cutlass_fused_moe( hidden_states, topk_ids.to(torch.int), topk_weights, @@ -174,8 +171,7 @@ def apply( tp_rank=self.tp_rank, ep_size=self.ep_size, ep_rank=self.ep_rank, - )[0] - # TODO(shuw): Handle the allocation from FlashInfer to framework - output.copy_(out) + output=output, + ) else: raise ValueError("Only nvfp4 quantization is currently supported.") diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py index b6d13d03c275..2aaee3a4d772 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -21,10 +21,8 @@ def get_local_sizes(): sizes.append((cu_sizes[i] - cu_sizes[i - 1]).item()) return sizes -#should be mk.FusedMoEPrepareAndFinalize chk pplx -class FlashInferCutlassMoEPrepareAndFinalize(MoEPrepareAndFinalizeNoEP - ): +class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): def __init__( self, quant_dtype: Optional[torch.dtype] = None, @@ -36,6 +34,10 @@ def __init__( self.block_shape = block_shape self.quant_dtype = quant_dtype + @property + def activation_format(self) -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + def max_num_tokens_per_rank(self) -> Optional[int]: return None From 7df49e19db9c1ed5f52816424a48c0144054e9ca Mon Sep 17 00:00:00 2001 From: shuw Date: Mon, 7 Jul 2025 21:44:48 +0000 Subject: [PATCH 07/30] Fix NoEP class --- vllm/model_executor/layers/quantization/modelopt.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index b6f54126ff3e..5c246b9262fb 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -513,8 +513,11 @@ def load_up_proj_weight_first(self) -> bool: # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13 return self.allow_flashinfer_cutlass - def select_experts_impl(self, moe_parallel_config): + def select_experts_impl(self, moe_parallel_config): if not self.allow_flashinfer_cutlass: + # if moe_parallel_config.dp_size > 1: + # raise ValueError("CutlassExpertsFp4 Doesn't support DP. 
" + # "Use flashinfer CUTLASS FusedMoE backend instead.") return logger.debug("FlashInferExperts") From 5af3db95fb206f5bb7d3139e04ae796696eac0be Mon Sep 17 00:00:00 2001 From: shuw Date: Wed, 9 Jul 2025 04:29:41 +0000 Subject: [PATCH 08/30] chunking work --- tests/distributed/test_pynccl.py | 72 ++++++++++++++++++- .../base_device_communicator.py | 32 ++++----- .../fused_moe/flashinfer_cutlass_moe.py | 2 +- .../flashinfer_cutlass_prepare_finalize.py | 18 +++-- vllm/model_executor/layers/fused_moe/layer.py | 17 +++-- .../layers/quantization/modelopt.py | 2 + vllm/model_executor/models/deepseek_v2.py | 1 + 7 files changed, 115 insertions(+), 29 deletions(-) diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index 5b32b90f3cfe..b624599fd1e4 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -4,6 +4,7 @@ import multiprocessing import os +import numpy as np import pytest import torch import torch.distributed @@ -177,6 +178,38 @@ def test_pynccl_all_gather(): distributed_run(all_gather_worker_fn, 2) +@worker_fn_wrapper +def all_gatherv_worker_fn(): + pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, + device=get_world_group().device) + + rank = pynccl_comm.rank + world_size = pynccl_comm.world_size + device = f'cuda:{pynccl_comm.rank}' + + assert world_size <= 8 + sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size] + num_elems = sizes[rank] + tensor = torch.arange(num_elems, dtype=torch.float32, + device=device) + rank * 100 + result = torch.zeros(sum(sizes), dtype=torch.float32, device=device) + + expected = torch.cat([ + torch.arange(sizes[r], dtype=torch.float32) + r * 100 + for r in range(world_size) + ]).to(device) + + pynccl_comm.all_gather(result, tensor, sizes=sizes) + torch.cuda.synchronize() + torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run the test.") +def test_pynccl_all_gatherv(): + distributed_run(all_gatherv_worker_fn, 2) + + @worker_fn_wrapper def reduce_scatter_worker_fn(): pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, @@ -214,6 +247,43 @@ def test_pynccl_reduce_scatter(): distributed_run(reduce_scatter_worker_fn, 2) +@worker_fn_wrapper +def reduce_scatterv_worker_fn(): + pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, + device=get_world_group().device) + + rank = pynccl_comm.rank + world_size = pynccl_comm.world_size + device = f'cuda:{pynccl_comm.rank}' + + assert world_size <= 8 + sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size] + num_elems = sum(sizes) + tensor = torch.arange(num_elems, dtype=torch.float32, + device=device) + rank * 100 + result = torch.zeros(sizes[rank], dtype=torch.float32, device=device) + + # Calculate expected result for this rank's chunk + all_tensors = [ + torch.arange(num_elems, dtype=torch.float32) + r * 100 + for r in range(world_size) + ] + sizes_cumsum = np.cumsum(sizes) + start = 0 if rank == 0 else sizes_cumsum[rank - 1] + end = sizes_cumsum[rank] + expected = sum(tensor[start:end] for tensor in all_tensors).to(device) + + pynccl_comm.reduce_scatter(result, tensor, sizes=sizes) + torch.cuda.synchronize() + torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run the test.") +def test_pynccl_reduce_scatterv(): + distributed_run(reduce_scatterv_worker_fn, 2) + + @pytest.mark.skipif(torch.cuda.device_count() < 2, 
reason="Need at least 2 GPUs to run the test.") def test_pynccl_with_cudagraph(): @@ -329,4 +399,4 @@ def test_ncclGetUniqueId(): # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] # as long as the function doesn't raise an exception, we're good - assert unique_id is not None + assert unique_id is not None \ No newline at end of file diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index 69e6f405fdde..212a5ad3bba4 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import threading -from typing import List, Optional, Union +from typing import Optional, Union from weakref import WeakValueDictionary import torch @@ -138,23 +138,17 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: input_size[dim + 1:]) return output_tensor - def all_gatherv(self, - input_: Union[torch.Tensor, List[torch.Tensor]], - dim: int = 0, - sizes: Optional[List[int]] = None): - assert False, "not implemented" - - def all_gatherv(self, - input_: Union[torch.Tensor, List[torch.Tensor]], - dim: int = 0, - sizes: Optional[List[int]] = None): - assert False, "not implemented" + def all_gatherv( + self, + input_: Union[torch.Tensor, list[torch.Tensor]], + dim: int = 0, + sizes: Optional[list[int]] = None + ) -> Union[torch.Tensor, list[torch.Tensor]]: + raise NotImplementedError def reduce_scatter(self, input_: torch.Tensor, - dim: int = -1, - sizes: Optional[List[int]] = None) -> torch.Tensor: - assert sizes is None, "Varying size reduce scatter not supported with base device communicator" + dim: int = -1) -> torch.Tensor: world_size = self.world_size # Bypass the function if we are using only 1 GPU. if world_size == 1: @@ -186,6 +180,12 @@ def reduce_scatter(self, # Reshape before returning return output_tensor.movedim(0, dim).contiguous() + def reduce_scatterv(self, + input_: torch.Tensor, + dim: int = -1, + sizes: Optional[list[int]] = None) -> torch.Tensor: + raise NotImplementedError + def gather(self, input_: torch.Tensor, dst: int = 0, @@ -271,4 +271,4 @@ def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: Combine the hidden states and router logits from the appropriate device. This is a no-op in the base class. 
""" - return hidden_states + return hidden_states \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index b8627b3b3f5c..b0381198113e 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -85,7 +85,7 @@ def supports_expert_map(self) -> bool: return False def supports_chunking(self) -> bool: - #TODO(shuw): support chunking later + #TODO(shuw): support chunking later, actually support in layer.py return False def workspace_shapes( diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py index 2aaee3a4d772..0f9baf27e105 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from typing import Optional +import vllm.envs as envs import torch from flashinfer import fp4_swizzle_blockscale @@ -14,13 +15,20 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP) -def get_local_sizes(): +def get_local_sizes(local_tokens): cu_sizes = get_forward_context().dp_metadata.cu_tokens_across_dp_cpu sizes = [cu_sizes[0].item()] for i in range(1, len(cu_sizes)): sizes.append((cu_sizes[i] - cu_sizes[i - 1]).item()) - return sizes + max_num_tokens = envs.VLLM_MOE_DP_CHUNK_SIZE + sizes_chunked = [max_num_tokens] * len(sizes) + if local_tokens < max_num_tokens: + # When the number of local tokens is less than max_num_tokens, all other + # ranks will also have fewer than max_num_tokens. The remaining tokens + # are accounted for as residual. 
+ sizes_chunked = [x % max_num_tokens for x in sizes] + return sizes_chunked class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): def __init__( @@ -56,6 +64,7 @@ def prepare( apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, use_dp: Optional[bool] = True, + local_tokens: int = -1, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: @@ -79,7 +88,7 @@ def prepare( topk_weights, topk_ids, a1q, a1q_scale = \ get_dp_group().all_gatherv([topk_weights, topk_ids, a1q, a1q_scale], dim=0, - sizes=get_local_sizes()) + sizes=get_local_sizes(local_tokens)) a1_m, a1_n = a1q.shape a1q_scale = fp4_swizzle_blockscale(a1q_scale, a1_m, a1_n * 2) @@ -93,11 +102,12 @@ def finalize( topk_ids: torch.Tensor, apply_router_weight_on_input: bool, use_dp: bool = False, + local_tokens: int = -1, ) -> None: if use_dp: fused_expert_output = get_dp_group().reduce_scatter( fused_expert_output, dim=0, - sizes=get_local_sizes(), + sizes=get_local_sizes(local_tokens), ) output.copy_(fused_expert_output) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index ea16807e1b08..f3485d800615 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -805,7 +805,8 @@ def __init__( self.batched_hidden_states: Optional[torch.Tensor] = None self.batched_router_logits: Optional[torch.Tensor] = None if (self.moe_parallel_config.use_pplx_kernels - or self.moe_parallel_config.use_deepep_ll_kernels): + or self.moe_parallel_config.use_deepep_ll_kernels + or self.moe_parallel_config.use_flashinfer_cutlass_kernels): self.batched_hidden_states = torch.zeros( (moe.max_num_tokens, self.hidden_size), dtype=moe.in_dtype, @@ -1414,6 +1415,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): expert_load_view=self.expert_load_view, logical_to_physical_map=self.logical_to_physical_map, logical_replica_count=self.logical_replica_count, + # start ) if not skip_result_store: @@ -1422,11 +1424,9 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): ctx = get_forward_context() #TODO(shuw):where is it? - # flashinfer_cutlass_kernels can handle TP+EP without DP - max_tokens_across_dp = (MOE_DP_CHUNK_SIZE if self.dp_size == 1 else - ctx.dp_metadata.max_tokens_across_dp_cpu) + # flashinfer_cutlass_kernels can handle: optional DP + TP/EP + max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens - num_tokens = full_hidden_states.size(0) for chunk_start_ in range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank): @@ -1446,15 +1446,18 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): def forward_impl(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): assert self.quant_method is not None + # Route to the chunked forward path using the FlashInfer Cutlass kernel + # only when data parallelism (DP) is enabled. 
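+        # For example (illustrative): TP=8, DP=1 keeps the single-shot path
+        # below, while TP=4, DP=2 with the FlashInfer CUTLASS backend routes to
+        # forward_impl_chunked so tokens are processed in
+        # VLLM_MOE_DP_CHUNK_SIZE-sized chunks before the DP all-gatherv.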
+ use_flashinfer_cutlass_kernels = self.dp_size > 1 and self.moe_parallel_config.use_flashinfer_cutlass_kernels if (self.moe_parallel_config.use_pplx_kernels - or self.moe_parallel_config.use_deepep_ll_kernels): + or self.moe_parallel_config.use_deepep_ll_kernels or + use_flashinfer_cutlass_kernels): return self.forward_impl_chunked(hidden_states, router_logits) do_naive_dispatch_combine: bool = ( self.dp_size > 1 and not self.moe_parallel_config.use_deepep_ht_kernels and not self.moe_parallel_config.use_flashinfer_cutlass_kernels) - # do_naive_dispatch_combine = True if do_naive_dispatch_combine: hidden_states, router_logits = get_ep_group().dispatch( hidden_states, router_logits) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 5c246b9262fb..e7e8fafbab70 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -843,9 +843,11 @@ def apply( } extra_prepare_args = { 'use_dp': layer.dp_size > 1, + 'local_tokens': x.shape[0], } extra_finalize_args = { 'use_dp': layer.dp_size > 1, + 'local_tokens': x.shape[0], } out = self.fused_experts( diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 2fa1294b79b9..b8f5ffb557a7 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -202,6 +202,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: * (1. / self.routed_scaling_factor) if self.tp_size > 1: + # print(f"final_hidden_states:{final_hidden_states.shape}") final_hidden_states = ( self.experts.maybe_all_reduce_tensor_model_parallel( final_hidden_states)) From 4c5fa6dc31d1b574fc7ad31b335b49186bc114d3 Mon Sep 17 00:00:00 2001 From: shuw Date: Wed, 9 Jul 2025 05:15:08 +0000 Subject: [PATCH 09/30] Address comments --- .../device_communicators/cuda_communicator.py | 93 +++++++------ .../device_communicators/pynccl.py | 128 +++++++++++------- .../device_communicators/pynccl_wrapper.py | 3 +- vllm/distributed/parallel_state.py | 6 + .../model_executor/layers/fused_moe/config.py | 6 - .../layers/fused_moe/cutlass_moe.py | 5 +- .../fused_moe/flashinfer_cutlass_moe.py | 81 ++++++----- .../flashinfer_cutlass_prepare_finalize.py | 4 +- vllm/model_executor/layers/fused_moe/layer.py | 1 + .../layers/quantization/modelopt.py | 6 +- 10 files changed, 185 insertions(+), 148 deletions(-) diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 727c64518837..72639996edef 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -1,14 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import List, Optional, Union +from typing import Optional, Union import torch from torch.distributed import ProcessGroup import vllm.envs as envs from vllm.logger import init_logger -from vllm.platforms import current_platform from .base_device_communicator import DeviceCommunicatorBase @@ -42,8 +41,6 @@ def __init__(self, CustomAllreduce) from vllm.distributed.device_communicators.pynccl import ( PyNcclCommunicator) - from vllm.distributed.device_communicators.quick_all_reduce import ( - QuickAllReduce) self.pynccl_comm: Optional[PyNcclCommunicator] = None if use_pynccl and self.world_size > 1: @@ -53,7 +50,6 @@ def __init__(self, ) 
self.ca_comm: Optional[CustomAllreduce] = None - self.qr_comm: Optional[QuickAllReduce] = None if use_custom_allreduce and self.world_size > 1: # Initialize a custom fast all-reduce implementation. self.ca_comm = CustomAllreduce( @@ -61,14 +57,6 @@ def __init__(self, device=self.device, ) - if current_platform.is_rocm(): - # Initialize a custom quick all-reduce implementation for AMD. - # Quick reduce is designed as a complement to custom allreduce. - # Based on quickreduce (https://github.com/mk1-project/quickreduce). - # If it's a rocm, 'use_custom_allreduce==True' means it must - # currently be an MI300 series. - self.qr_comm = QuickAllReduce(group=self.cpu_group, - device=self.device) if self.use_all2all: all2all_backend = envs.VLLM_ALL2ALL_BACKEND if all2all_backend == "naive": @@ -91,14 +79,8 @@ def __init__(self, raise ValueError(f"Unknown all2all backend: {all2all_backend}") def all_reduce(self, input_): - # always try quick reduce first, then custom allreduce, - # and then pynccl. (quick reduce just for ROCM MI3*) - qr_comm = self.qr_comm - if qr_comm is not None and not qr_comm.disabled and \ - qr_comm.should_quick_allreduce(input_): - out = qr_comm.quick_all_reduce(input_) - assert out is not None - return out + # always try custom allreduce first, + # and then pynccl. ca_comm = self.ca_comm if ca_comm is not None and not ca_comm.disabled and \ ca_comm.should_custom_ar(input_): @@ -117,10 +99,35 @@ def all_reduce(self, input_): torch.distributed.all_reduce(out, group=self.device_group) return out - def reduce_scatter(self, - input_: torch.Tensor, - dim: int = -1, - sizes: Optional[List[int]] = None): + def reduce_scatter(self, input_: torch.Tensor, dim: int = -1): + world_size = self.world_size + pynccl_comm = self.pynccl_comm + assert pynccl_comm is not None + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + + # Note: This will produce an incorrect answer if we don't make + # the input_tensor contiguous. Possible bug in reduce_scatter_tensor? + input_tensor = input_.movedim(0, dim).contiguous() + + assert input_tensor.shape[0] % world_size == 0 + chunk_size = input_tensor.shape[0] // world_size + output_shape = (chunk_size, ) + input_tensor.shape[1:] + + output = torch.empty(output_shape, + dtype=input_tensor.dtype, + device=input_tensor.device) + + pynccl_comm.reduce_scatter(output, input_) + + # Reshape before returning + return output.movedim(0, dim).contiguous() + + def reduce_scatterv(self, + input_: torch.Tensor, + dim: int = -1, + sizes: Optional[list[int]] = None): world_size = self.world_size pynccl_comm = self.pynccl_comm assert pynccl_comm is not None @@ -145,7 +152,10 @@ def reduce_scatter(self, dtype=input_tensor.dtype, device=input_tensor.device) - pynccl_comm.reduce_scatter(output, input_, sizes=sizes) + if sizes is not None: + pynccl_comm.reduce_scatterv(output, input_, sizes=sizes) + else: + pynccl_comm.reduce_scatter(output, input_) # Reshape before returning return output.movedim(0, dim).contiguous() @@ -188,32 +198,25 @@ def destroy(self): self.all2all_manager.destroy() self.all2all_manager = None - """ - Allgather with support for list of tensors and varying sizes per rank. - Example: - Instead of: - ... = get_ep_group().dispatch(...) - Use this: - ... 
= get_dp_group().all_gatherv([topk_weights, topk_ids, a1q, a1q_scale], dim=0, sizes=get_forward_context().dp_metadata.num_tokens_across_dp_cpu) - """ - def all_gatherv(self, - input_: Union[torch.Tensor, List[torch.Tensor]], + input_: Union[torch.Tensor, list[torch.Tensor]], dim: int = 0, - sizes: Optional[List[int]] = None): - assert dim == 0, "only dim 0 all-gather is supported" + sizes: Optional[list[int]] = None): + if dim != 0: + raise NotImplementedError("only dim 0 all-gatherv is supported") world_size = self.world_size pynccl_comm = self.pynccl_comm assert pynccl_comm is not None and not pynccl_comm.disabled def _all_gather_single(input_: torch.Tensor, - sizes: Optional[List[int]] = None): + sizes: Optional[list[int]] = None): input_size = input_.size() if sizes is not None: assert len(sizes) == world_size assert input_.shape[dim] == sizes[self.rank_in_group] output_size = (sum(sizes), ) + input_size[1:] - # 'sizes' is not needed if all inputs in the same group have the same shape + # 'sizes' is not needed if all inputs in the same group have the + # same shape if all(s == sizes[0] for s in sizes): sizes = None else: @@ -222,17 +225,21 @@ def _all_gather_single(input_: torch.Tensor, output_tensor = torch.empty(output_size, dtype=input_.dtype, device=input_.device) - pynccl_comm.all_gather(output_tensor, input_, sizes=sizes) + if sizes is not None: + pynccl_comm.all_gatherv(output_tensor, input_, sizes=sizes) + else: + pynccl_comm.all_gather(output_tensor, input_) return output_tensor if isinstance(input_, torch.Tensor): return _all_gather_single(input_, sizes) - pynccl_comm.group_start() output_list = [] + pynccl_comm.group_start() for inp in input_: output_list.append(_all_gather_single(inp, sizes=sizes)) pynccl_comm.group_end() + return output_list def dispatch( @@ -246,4 +253,4 @@ def dispatch( def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: assert self.all2all_manager is not None hidden_states = self.all2all_manager.combine(hidden_states) - return hidden_states + return hidden_states \ No newline at end of file diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 5829e4e460cf..9f20bb1ef289 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -1,9 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import List, Optional, Union +from typing import Optional, Union -import numpy as np # ===================== import region ===================== import torch import torch.distributed as dist @@ -136,8 +135,7 @@ def all_reduce(self, def all_gather(self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, - stream=None, - sizes: Optional[List[int]] = None): + stream=None): if self.disabled: return # nccl communicator created on a specific device @@ -148,38 +146,51 @@ def all_gather(self, f"but the input tensor is on {input_tensor.device}") if stream is None: stream = current_stream() - if sizes is not None: - assert output_tensor.shape[0] == sum(sizes) - numel_base = int(np.prod(output_tensor.shape[1:])) - split_offset = 0 - self.nccl.ncclGroupStart() - for root, split_size in enumerate(sizes): - dst_slice = output_tensor[split_offset:split_offset + - split_size] - self.nccl.ncclBroadcast( - buffer_type(input_tensor.data_ptr()), - buffer_type(dst_slice.data_ptr()), - split_size * numel_base, - ncclDataTypeEnum.from_torch(input_tensor.dtype), - root, - self.comm, - 
cudaStream_t(stream.cuda_stream), - ) - split_offset += split_size - self.nccl.ncclGroupEnd() - else: - self.nccl.ncclAllGather( + self.nccl.ncclAllGather( + buffer_type(input_tensor.data_ptr()), + buffer_type(output_tensor.data_ptr()), input_tensor.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm, + cudaStream_t(stream.cuda_stream)) + + def all_gatherv( + self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + sizes: list[int], + stream=None, + ): + if self.disabled: + return + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert input_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {input_tensor.device}") + if stream is None: + stream = current_stream() + assert output_tensor.shape[0] == sum(sizes) + split_offset = 0 + self.nccl.ncclGroupStart() + for root, split_size in enumerate(sizes): + dst_slice = output_tensor[split_offset:split_offset + split_size] + self.nccl.ncclBroadcast( buffer_type(input_tensor.data_ptr()), - buffer_type(output_tensor.data_ptr()), input_tensor.numel(), - ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm, - cudaStream_t(stream.cuda_stream)) + buffer_type(dst_slice.data_ptr()), + dst_slice.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), + root, + self.comm, + cudaStream_t(stream.cuda_stream), + ) + split_offset += split_size + self.nccl.ncclGroupEnd() def reduce_scatter(self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, - stream=None, - sizes: Optional[List[int]] = None): + stream=None): if self.disabled: return # nccl communicator created on a specific device @@ -190,29 +201,44 @@ def reduce_scatter(self, f"but the input tensor is on {input_tensor.device}") if stream is None: stream = current_stream() + self.nccl.ncclReduceScatter( + buffer_type(input_tensor.data_ptr()), + buffer_type(output_tensor.data_ptr()), output_tensor.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), self.comm, + cudaStream_t(stream.cuda_stream)) - if sizes is not None: - numel_base = int(np.prod(input_tensor.shape[1:])) - split_offset = 0 - self.nccl.ncclGroupStart() - for root, split_size in enumerate(sizes): - chunk = input_tensor[split_offset:split_offset + split_size, :] - self.nccl.ncclReduce( - buffer_type(chunk.data_ptr()), - buffer_type(output_tensor.data_ptr()), - split_size * numel_base, - ncclDataTypeEnum.from_torch(input_tensor.dtype), - ncclRedOpTypeEnum.from_torch(op), root, self.comm, - cudaStream_t(stream.cuda_stream)) - split_offset += split_size - self.nccl.ncclGroupEnd() - else: - self.nccl.ncclReduceScatter( - buffer_type(input_tensor.data_ptr()), - buffer_type(output_tensor.data_ptr()), output_tensor.numel(), + def reduce_scatterv( + self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + sizes: list[int], + op: ReduceOp = ReduceOp.SUM, + stream=None, + ): + if self.disabled: + return + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert input_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {input_tensor.device}") + if stream is None: + stream = current_stream() + + split_offset = 0 + self.nccl.ncclGroupStart() + for root, split_size in 
enumerate(sizes): + chunk = input_tensor[split_offset:split_offset + split_size, ...] + self.nccl.ncclReduce( + buffer_type(chunk.data_ptr()), + buffer_type(output_tensor.data_ptr()), chunk.numel(), ncclDataTypeEnum.from_torch(input_tensor.dtype), - ncclRedOpTypeEnum.from_torch(op), self.comm, + ncclRedOpTypeEnum.from_torch(op), root, self.comm, cudaStream_t(stream.cuda_stream)) + split_offset += split_size + self.nccl.ncclGroupEnd() def send(self, tensor: torch.Tensor, dst: int, stream=None): if self.disabled: @@ -261,4 +287,4 @@ def group_start(self): self.nccl.ncclGroupStart() def group_end(self): - self.nccl.ncclGroupEnd() + self.nccl.ncclGroupEnd() \ No newline at end of file diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py index c8e772447ee9..8efda853c05a 100644 --- a/vllm/distributed/device_communicators/pynccl_wrapper.py +++ b/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -164,6 +164,7 @@ class NCCLLibrary: buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, ncclRedOp_t, ctypes.c_int, ncclComm_t, cudaStream_t ]), + # ncclResult_t ncclAllGather( # const void* sendbuff, void* recvbuff, size_t count, # ncclDataType_t datatype, ncclComm_t comm, @@ -378,4 +379,4 @@ def ncclGroupEnd(self) -> None: __all__ = [ "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId", "ncclComm_t", "cudaStream_t", "buffer_type" -] +] \ No newline at end of file diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 964d13b61699..16e5ce986267 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -408,6 +408,12 @@ def reduce_scatter(self, else: return self._reduce_scatter_out_place(input_, dim, sizes) + def reduce_scatterv(self, + input_: torch.Tensor, + dim: int = -1, + sizes: Optional[list[int]] = None) -> torch.Tensor: + return self.device_communicator.reduce_scatterv(input_, dim, sizes) + def _reduce_scatter_out_place( self, input_: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 02d95cc2210e..9c093683dae4 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -128,16 +128,10 @@ class FusedMoEParallelConfig: def use_all2all_kernels(self): return self.dp_size > 1 and self.use_ep - @property - def use_pplx_kernels(self): - return (self.use_all2all_kernels - and envs.VLLM_ALL2ALL_BACKEND == "pplx") - @property def use_pplx_kernels(self): return (self.use_all2all_kernels and envs.VLLM_ALL2ALL_BACKEND == "pplx") - @property def use_deepep_ht_kernels(self): return (self.use_all2all_kernels diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 8138a37de128..e88b5c28dd42 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -414,7 +414,7 @@ def run_cutlass_moe_fp4( k: int, e: int, device: torch.device, -): +) -> None: """ MoE implementation for FP4 Inputs @@ -572,6 +572,9 @@ def workspace_shapes( ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: # Workspace1: for c1, Workspace2: for intermediate, Output: for final output workspace1 = a.shape + # workspace 2 remains empty because run_cutlass_moe_fp4 allocates the + # intermediate tensor there. + # # workspace 1 is allocated to store the output. 
workspace2 = () output = a.shape return (workspace1, workspace2, output, self.out_dtype) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index b0381198113e..896d4201bf58 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -25,13 +25,12 @@ has_flashinfer_cutlass_fused_moe = flashinfer_cutlass_fused_moe is not None -#TODO(shuw): use this check + def _valid_flashinfer_fused_moe(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor) -> bool: """ - Check if the given problem size is supported by the DeepGemm grouped - gemm kernel. All of M, N, K and the quantization block_shape must be - aligned by `dg.get_m_alignment_for_contiguous_layout()`. + Check if the given problem size is supported by the FlashInfer CUTLASS MoE + kernel. """ if not has_flashinfer_cutlass_fused_moe: logger.debug( @@ -110,16 +109,17 @@ def workspace_shapes( - Note: in order for activation chunking to work, the first dimension of each tuple must be the number of tokens. """ - if self.use_nvfp4_w4a4: - aq_m, aq_n = aq.shape - workspace2 = () - output_shape = (aq_m, aq_n * 2) - workspace_dtype = a.dtype - workspace1 = output_shape - # determined by aq, since aq is the one after possible communication op and participate in experts computation. - return (workspace1, workspace2, output_shape, workspace_dtype) - else: - raise ValueError("Only nvfp4 quantization is currently supported.") + assert self.use_nvfp4_w4a4 is True, ("Only nvfp4 quantization is " + "currently supported.") + aq_m, aq_n = aq.shape + workspace2 = () + output_shape = (aq_m, aq_n * 2) + workspace_dtype = a.dtype + workspace1 = output_shape + # The workspace is determined by `aq`, since it comes after any + # potential communication op and is involved in the expert computation. + return (workspace1, workspace2, output_shape, workspace_dtype) + def apply( self, @@ -148,30 +148,29 @@ def apply( ): # Flashinfer CUTLASS kernel takes scalar global scales, # min because inv_scale. 
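
For reference, a minimal sketch of how the "scalar global scales" mentioned in the comment above are assembled into the quant_scales list that this series passes to flashinfer_cutlass_fused_moe. The attribute names follow the modelopt.py hunks later in the series; the helper itself is illustrative only and is not part of the patch.

import torch

# Illustrative only -- not part of the patch. Attribute names follow the
# modelopt.py hunks in this series (w13_input_scale_quant, g1_alphas, ...).
def build_nvfp4_quant_scales(layer) -> list[torch.Tensor]:
    a1_gscale = torch.min(layer.w13_input_scale_quant)  # scalar ("min because inv_scale")
    a2_gscale = torch.min(layer.w2_input_scale_quant)
    return [
        a1_gscale,
        layer.w13_blockscale_swizzled.view(torch.int32),  # w1 block scales
        layer.g1_alphas,                                  # per-expert alphas for GEMM1
        a2_gscale,
        layer.w2_blockscale_swizzled.view(torch.int32),   # w2 block scales
        layer.g2_alphas,                                  # per-expert alphas for GEMM2
    ]
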
- if self.use_nvfp4_w4a4: - quant_scales = [ - a1_scale, - w1_scale.view(torch.int32), - g1_alphas, - a2_scale, - w2_scale.view(torch.int32), - g2_alphas, - ] - _ = flashinfer_cutlass_fused_moe( - hidden_states, - topk_ids.to(torch.int), - topk_weights, - # FlashInfer API requires weight to be long for nvfp4 - w1.view(torch.long), - w2.view(torch.long), - output_dtype=out_dtype, - quant_scales=quant_scales, - input_sf=a1q_scale, - tp_size=self.tp_size, - tp_rank=self.tp_rank, - ep_size=self.ep_size, - ep_rank=self.ep_rank, - output=output, - ) - else: - raise ValueError("Only nvfp4 quantization is currently supported.") + assert self.use_nvfp4_w4a4 is True, ("Only nvfp4 quantization is " + "currently supported.") + quant_scales = [ + a1_scale, + w1_scale.view(torch.int32), + g1_alphas, + a2_scale, + w2_scale.view(torch.int32), + g2_alphas, + ] + _ = flashinfer_cutlass_fused_moe( + hidden_states, + topk_ids.to(torch.int), + topk_weights, + # FlashInfer API requires weight to be long for nvfp4 + w1.view(torch.long), + w2.view(torch.long), + output_dtype=out_dtype, + quant_scales=quant_scales, + input_sf=a1q_scale, + tp_size=self.tp_size, + tp_rank=self.tp_rank, + ep_size=self.ep_size, + ep_rank=self.ep_rank, + output=output, + ) \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py index 0f9baf27e105..7ac4d08eab39 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -4,7 +4,6 @@ import vllm.envs as envs import torch -from flashinfer import fp4_swizzle_blockscale from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig import vllm.model_executor.layers.fused_moe.modular_kernel as mk @@ -90,6 +89,7 @@ def prepare( dim=0, sizes=get_local_sizes(local_tokens)) a1_m, a1_n = a1q.shape + from flashinfer import fp4_swizzle_blockscale a1q_scale = fp4_swizzle_blockscale(a1q_scale, a1_m, a1_n * 2) return a1q, a1q_scale, None, topk_ids, topk_weights @@ -105,7 +105,7 @@ def finalize( local_tokens: int = -1, ) -> None: if use_dp: - fused_expert_output = get_dp_group().reduce_scatter( + fused_expert_output = get_dp_group().reduce_scatterv( fused_expert_output, dim=0, sizes=get_local_sizes(local_tokens), diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index f3485d800615..379310cfb1e4 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1448,6 +1448,7 @@ def forward_impl(self, hidden_states: torch.Tensor, assert self.quant_method is not None # Route to the chunked forward path using the FlashInfer Cutlass kernel # only when data parallelism (DP) is enabled. + # TODO(shuw): Make TP calling also chunked. 
use_flashinfer_cutlass_kernels = self.dp_size > 1 and self.moe_parallel_config.use_flashinfer_cutlass_kernels if (self.moe_parallel_config.use_pplx_kernels or self.moe_parallel_config.use_deepep_ll_kernels or diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index e7e8fafbab70..fb631fb1e2b2 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -33,7 +33,7 @@ from vllm.platforms import current_platform from vllm.scalar_type import scalar_types from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( - FlashInferExperts) + FlashInferExperts, _valid_flashinfer_fused_moe) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( FlashInferCutlassMoEPrepareAndFinalize) import vllm.model_executor.layers.fused_moe.modular_kernel as mk @@ -832,8 +832,8 @@ def apply( if self.allow_flashinfer_cutlass: # TP or DP case - # import pdb - # pdb.set_trace() + assert _valid_flashinfer_fused_moe( + x, layer.w13_weight, layer.w2_weight), ("Flashinfer CUTLASS Fused MoE not applicable!") extra_expert_args = { 'topk_weights': None, #placeholder topk_weights, 'g1_alphas': layer.g1_alphas, From f99cf65eb5d6fdd265812f43de98e4a32c718b07 Mon Sep 17 00:00:00 2001 From: shuw Date: Wed, 9 Jul 2025 15:32:45 +0000 Subject: [PATCH 10/30] minor fix --- .../device_communicators/base_device_communicator.py | 2 +- vllm/distributed/device_communicators/pynccl.py | 2 +- vllm/model_executor/layers/fused_moe/layer.py | 1 - vllm/model_executor/models/deepseek_v2.py | 1 - 4 files changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index 212a5ad3bba4..6e6ccd740c6d 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -271,4 +271,4 @@ def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: Combine the hidden states and router logits from the appropriate device. This is a no-op in the base class. 
""" - return hidden_states \ No newline at end of file + return hidden_states diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 9f20bb1ef289..502bfd39005a 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -287,4 +287,4 @@ def group_start(self): self.nccl.ncclGroupStart() def group_end(self): - self.nccl.ncclGroupEnd() \ No newline at end of file + self.nccl.ncclGroupEnd() diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 379310cfb1e4..373ba48e6182 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1415,7 +1415,6 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): expert_load_view=self.expert_load_view, logical_to_physical_map=self.logical_to_physical_map, logical_replica_count=self.logical_replica_count, - # start ) if not skip_result_store: diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index b8f5ffb557a7..2fa1294b79b9 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -202,7 +202,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: * (1. / self.routed_scaling_factor) if self.tp_size > 1: - # print(f"final_hidden_states:{final_hidden_states.shape}") final_hidden_states = ( self.experts.maybe_all_reduce_tensor_model_parallel( final_hidden_states)) From dbefd52cf47d76a03e7af07fe164fa5eb48db2e2 Mon Sep 17 00:00:00 2001 From: shuw Date: Thu, 10 Jul 2025 03:59:11 +0000 Subject: [PATCH 11/30] cutlass_moe_fp4 support TP chunking --- vllm/_custom_ops.py | 10 +- .../device_communicators/all2all.py | 3 +- .../layers/fused_moe/cutlass_moe.py | 135 +++++++++++++----- .../layers/fused_moe/modular_kernel.py | 20 +++ .../layers/quantization/modelopt.py | 6 +- 5 files changed, 128 insertions(+), 46 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 6b1b3f787c23..9b520c71558f 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -942,7 +942,8 @@ def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, c_strides, per_act_token, per_out_ch) -def cutlass_fp4_moe_mm(a_tensors: torch.Tensor, b_tensors: torch.Tensor, +def cutlass_fp4_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, + b_tensors: torch.Tensor, a_scales: torch.Tensor, b_scales: torch.Tensor, alphas: torch.Tensor, problem_sizes: torch.Tensor, expert_offsets: torch.Tensor, sf_offsets: torch.Tensor, @@ -963,14 +964,9 @@ def cutlass_fp4_moe_mm(a_tensors: torch.Tensor, b_tensors: torch.Tensor, - problem_sizes: MxNxK sizes of each expert's multiplication in two grouped MMs used in the fused MoE operation. 
""" - m_topk = a_tensors.shape[0] - n = b_tensors.shape[1] - c_shape = (m_topk, n) - c = torch.empty(c_shape, device=device, dtype=out_dtype) - torch.ops._C.cutlass_fp4_group_mm(c, a_tensors, b_tensors, a_scales, + return torch.ops._C.cutlass_fp4_group_mm(out_tensors, a_tensors, b_tensors, a_scales, b_scales, alphas, problem_sizes, expert_offsets, sf_offsets) - return c.to(out_dtype) # aqlm diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 85f87cb21edc..b7beb70173dd 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -65,7 +65,8 @@ def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ self.dp_rank - 1] end = cu_tokens_across_dp_cpu[self.dp_rank] - + print(f"dp_ws:{self.dp_group.world_size}; ws:{self.world_size}") + print(f"dp_rank:{self.dp_rank}; rank: {self.rank}") all_hidden_states = self.dp_group.all_reduce(hidden_states) hidden_states = all_hidden_states[start:end, :] return hidden_states diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index e88b5c28dd42..d3488c435748 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -408,13 +408,30 @@ def run_cutlass_moe_fp4( w2_blockscale: torch.Tensor, w2_alphas: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, + topk_ids: torch.Tensor, + workspace13: torch.Tensor, + workspace2: torch.Tensor, m: int, n: int, k: int, e: int, device: torch.device, ) -> None: + # print("---------------start-------------") + # print(f"output.shape: {output.shape}") # [m, n] + # print(f"a.shape: {a.shape}") # [m, k] + # print(f"a1_gscale.shape: {a1_gscale.shape}") # [1] + # print(f"w1_fp4.shape: {w1_fp4.shape}") # [e, k//2, n] + # print(f"w1_blockscale.shape: {w1_blockscale.shape}") # [e, n//groupsize] + # print(f"w1_alphas.shape: {w1_alphas.shape}") # [e, 1] + # print(f"a2_gscale.shape: {a2_gscale.shape}") # [1] + # print(f"w2_fp4.shape: {w2_fp4.shape}") # [e, n//2, k] + # print(f"w2_blockscale.shape: {w2_blockscale.shape}") # [e, k//groupsize] + # print(f"w2_alphas.shape: {w2_alphas.shape}") # [e, 1] + # print(f"topk_weights.shape: {topk_weights.shape}") # [m, top_k] + # print(f"topk_ids.shape: {topk_ids.shape}") # [m, top_k] + # print(f"m:{m}, n:{n}, k:{k}, e:{e}") + """ MoE implementation for FP4 Inputs @@ -454,16 +471,18 @@ def run_cutlass_moe_fp4( assert (e_w1 == e_w2 and e_w1 == e), ("Number of experts must match", " between weights.") - assert (k_a // 2 == half_k_w1 + assert (k_a == half_k_w1 * 2 and k == k_w2), ("Hidden size mismatch between a, w1 and w2") - assert (nx2_w1 == n * 2 and half_n_w2 == n // 2), ("mismatch in " + assert (nx2_w1 == n * 2 and half_n_w2 * 2 == n), ("mismatch in " "expected `n`") + # print(f"m:{m}; m_a:{m_a}") + # m = m_a assert (m == m_a), "input shape mismatch" assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1" assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype" assert (topk_weights.size(0) == m and topk_ids.size(0) == m), ("topk must be provided for each row of a") - + topk = topk_ids.size(1) out_dtype = a.dtype num_topk = topk_ids.size(1) @@ -484,7 +503,7 @@ def run_cutlass_moe_fp4( blockscale_offsets) a = ops.shuffle_rows(a, a_map) - + # print(f"a:{a.shape}; a1_gscale:{a1_gscale.shape}") rep_a_fp4, rep_a_blockscale = 
ops.scaled_fp4_experts_quant( a, a1_gscale, @@ -492,31 +511,45 @@ def run_cutlass_moe_fp4( blockscale_offsets, num_topk, ) - - c1 = ops.cutlass_fp4_moe_mm(rep_a_fp4, w1_fp4, rep_a_blockscale, + c1 = _resize_cache(workspace13, (m * topk, n * 2)) + c2 = _resize_cache(workspace2, (m * topk, n)) + c3 = _resize_cache(workspace13, (m * topk, k)) + ops.cutlass_fp4_moe_mm(c1, rep_a_fp4, w1_fp4, rep_a_blockscale, w1_blockscale, w1_alphas, problem_sizes1, expert_offsets[:-1], blockscale_offsets[:-1], out_dtype, device) + # print(f"c1:{c1.shape}") del rep_a_fp4, rep_a_blockscale # hidden size dimension is split to one halfpytho sized tensor. - intermediate = torch.empty((m * num_topk, w1_fp4.size(1) // 2), - device=device, - dtype=out_dtype) - - torch.ops._C.silu_and_mul(intermediate, c1) - + # intermediate = torch.empty((m * num_topk, w1_fp4.size(1) // 2), + # device=device, + # dtype=out_dtype) + # print(f"intermediate:{intermediate.shape}") + torch.ops._C.silu_and_mul(c2, c1) + # torch.ops._C.silu_and_mul(intermediate, c1) + # print(f"c2:{c2.shape}") + # print(f"a2_gscale:{a2_gscale.shape}") + # import pdb + # pdb.set_trace() int_fp4, int_blockscale = ops.scaled_fp4_experts_quant( - intermediate, a2_gscale, expert_offsets, blockscale_offsets, num_topk) + c2, a2_gscale, expert_offsets, blockscale_offsets, num_topk) - c2 = ops.cutlass_fp4_moe_mm(int_fp4, w2_fp4, int_blockscale, w2_blockscale, + ops.cutlass_fp4_moe_mm(c3, int_fp4, w2_fp4, int_blockscale, w2_blockscale, w2_alphas, problem_sizes2, expert_offsets[:-1], blockscale_offsets[:-1], out_dtype, device) + # print(f"c2:{c2.shape}") del int_fp4, int_blockscale - c2 = ops.shuffle_rows(c2, c_map) + c3 = ops.shuffle_rows(c3, c_map) + assert output.dtype == out_dtype - output.copy_((c2.view(m, num_topk, k) * - topk_weights.view(m, num_topk, 1).half()).sum(dim=1)) + # print(f"output:{output.shape}") + # print(f"c3:{c3.view(m, num_topk, k).shape}") + # print(f"b:{topk_weights.view(m, num_topk, 1).half().shape}") + # print(f"g:{g.shape}") + output.copy_((c3.view(m, num_topk, k) * + topk_weights.view(m, num_topk, 1).half()).sum(dim=1), non_blocking=True) + # print("---------------end-------------") return @@ -557,8 +590,28 @@ def supports_expert_map(self) -> bool: return False def supports_chunking(self) -> bool: - return False - + return True + + # def workspace_shapes( + # self, + # a: torch.Tensor, + # aq: torch.Tensor, + # M: int, + # N: int, + # K: int, + # topk: int, + # global_num_experts: int, + # local_num_experts: int, + # ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + # # Workspace1: for c1, Workspace2: for intermediate, Output: for final output + # workspace1 = a.shape + # # workspace 2 remains empty because run_cutlass_moe_fp4 allocates the + # # intermediate tensor there. + # # # workspace 1 is allocated to store the output. + # workspace2 = () + # output = a.shape + # return (workspace1, workspace2, output, self.out_dtype) + def workspace_shapes( self, a: torch.Tensor, @@ -570,14 +623,20 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: - # Workspace1: for c1, Workspace2: for intermediate, Output: for final output - workspace1 = a.shape - # workspace 2 remains empty because run_cutlass_moe_fp4 allocates the - # intermediate tensor there. - # # workspace 1 is allocated to store the output. - workspace2 = () - output = a.shape - return (workspace1, workspace2, output, self.out_dtype) + workspace1: tuple[int, ...] 
= () + workspace2: tuple[int, ...] = () + output: tuple[int, ...] = () + if self.use_batched_format: + padded_M = aq.size(1) + workspace1 = (self.max_experts_per_worker, padded_M, max(N, K)) + workspace2 = (self.max_experts_per_worker, padded_M, (N // 2)) + output = (self.max_experts_per_worker, padded_M, K) + else: + workspace1 = (M * topk, max(2 * N, K)) + workspace2 = (M * topk, N) + output = (M, K) + return (workspace1, workspace2, output, + self.out_dtype if self.out_dtype is not None else a.dtype) def apply( self, @@ -602,7 +661,8 @@ def apply( topk_weights: torch.Tensor, g1_alphas: torch.Tensor, g2_alphas: torch.Tensor, - a1_scale: torch.Tensor, + a1_gscale: torch.Tensor, + a2_gscale: torch.Tensor, m: int, n: int, k: int, @@ -615,16 +675,18 @@ def apply( run_cutlass_moe_fp4( output, hidden_states, - a1_scale, + a1_gscale, w1, w1_scale, g1_alphas, - a2_scale, + a2_gscale, w2, w2_scale, g2_alphas, topk_weights, topk_ids, + workspace13, + workspace2, m, n, k, @@ -641,8 +703,8 @@ def cutlass_moe_fp4( w2_blockscale: torch.Tensor, g1_alphas: torch.Tensor, g2_alphas: torch.Tensor, - a1_scale: torch.Tensor, - a2_scale: torch.Tensor, + a1_gscale: torch.Tensor, + a2_gscale: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, m: int, @@ -669,7 +731,8 @@ def cutlass_moe_fp4( 'topk_weights': topk_weights, 'g1_alphas': g1_alphas, 'g2_alphas': g2_alphas, - 'a1_scale': a1_scale, + 'a1_gscale': a1_gscale, + 'a2_gscale': a2_gscale, 'm': m, 'n': n, 'k': k, @@ -694,8 +757,8 @@ def cutlass_moe_fp4( expert_map=None, w1_scale=w1_blockscale, w2_scale=w2_blockscale, - a1_scale=a1_scale, - a2_scale=a2_scale, + a1_scale=None, + a2_scale=None, apply_router_weight_on_input=False, extra_expert_args=extra_expert_args, extra_prepare_args=extra_prepare_args, diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 294d87840880..4d7605a6837e 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -498,9 +498,11 @@ def forward( if self.fused_experts.enable_chunking(): CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE num_chunks = cdiv(M, CHUNK_SIZE) + # print(f"1CHUNK_SIZE:{CHUNK_SIZE} and M:{M}") else: CHUNK_SIZE = M num_chunks = 1 + # print(f"2CHUNK_SIZE:{CHUNK_SIZE} and M:{M}") if num_chunks == 1: (workspace13_shape, workspace2_shape, fused_out_shape, @@ -530,6 +532,7 @@ def forward( expert_kwargs = extra_expert_args or {} if num_chunks == 1: + # print('gggg'*100) fused_out = _resize_cache(workspace13, fused_out_shape) if 'topk_weights' in expert_kwargs and expert_kwargs['topk_weights'] is None: expert_kwargs['topk_weights'] = topk_weights @@ -557,6 +560,7 @@ def forward( else: # The leading output dimension may not be equal to M, so # we compute output indices separately. 
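
As a side note on the chunked path in this hunk: once the token dimension is split into CHUNK_SIZE pieces, every per-token input has to be re-sliced for each chunk, including the extra expert kwargs ('m', 'topk_weights') added here. A minimal, self-contained sketch of that slicing, with illustrative names only (not vLLM code):

import torch

# Illustrative sketch only -- not vLLM code. Shows how per-token tensors and
# the extra expert kwargs are cut to the same chunk of rows.
def iter_chunks(a1q: torch.Tensor, topk_ids: torch.Tensor,
                topk_weights: torch.Tensor, chunk_size: int):
    num_tokens = a1q.shape[0]
    for start in range(0, num_tokens, chunk_size):
        end = min(start + chunk_size, num_tokens)
        extra_kwargs = {
            "m": end - start,                       # rows in this chunk
            "topk_weights": topk_weights[start:end],
        }
        yield a1q[start:end], topk_ids[start:end], extra_kwargs

# Example: 7 tokens with chunk_size=3 yields chunks of 3, 3 and 1 rows.
a1q = torch.randn(7, 16)
topk_ids = torch.zeros(7, 2, dtype=torch.int64)
topk_weights = torch.rand(7, 2)
chunk_rows = [kw["m"] for _, _, kw in iter_chunks(a1q, topk_ids, topk_weights, 3)]
assert chunk_rows == [3, 3, 1]
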
+ # print('ttt'*100) M_out = fused_out_shape[0] assert M_out >= M factor = M_out // M @@ -582,6 +586,22 @@ def forward( end_chunk_idx) curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] + if 'topk_weights' in expert_kwargs and expert_kwargs['topk_weights'] is not None: + expert_kwargs['topk_weights'] = topk_weights[begin_chunk_idx:end_chunk_idx] + assert expert_kwargs['topk_weights'] is not None + + if 'm' in expert_kwargs and expert_kwargs['m'] is not None: + expert_kwargs['m'] = end_chunk_idx - begin_chunk_idx + + + # if 'a1_scale' in expert_kwargs and expert_kwargs['a1_scale'] is not None: + # print("before swapping a1_scale"*10) + # print(f"a1_scale in kwargs:{expert_kwargs['a1_scale'].shape}") + # expert_kwargs['a1_scale'] = expert_kwargs['a1_scale'][begin_chunk_idx:end_chunk_idx] + # print(f"a1_scale in kwargs:{expert_kwargs['a1_scale'].shape}") + # print("after swapping a1_scale"*10) + # assert expert_kwargs['a1_scale'] is not None + self.fused_experts.apply( fused_out[begin_out_idx:end_out_idx], curr_a1q, diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index fb631fb1e2b2..964250aa9dda 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -839,6 +839,8 @@ def apply( 'g1_alphas': layer.g1_alphas, 'g2_alphas': layer.g2_alphas, 'out_dtype': x.dtype, + # Avoid confusion with a1_scale and a2_scale whare are batch size + # related. 'a1_scale': torch.min(layer.w13_input_scale_quant), } extra_prepare_args = { @@ -881,8 +883,8 @@ def apply( w2_blockscale=layer.w2_blockscale_swizzled, g1_alphas=layer.g1_alphas, g2_alphas=layer.g2_alphas, - a1_scale=layer.w13_input_scale_quant, - a2_scale=layer.w2_input_scale_quant, + a1_gscale=layer.w13_input_scale_quant, + a2_gscale=layer.w2_input_scale_quant, topk_weights=topk_weights, topk_ids=topk_ids, m=x.shape[0], From 3f043ddfd0efcf284ae8641b38953bb696d6f3a9 Mon Sep 17 00:00:00 2001 From: shuw Date: Thu, 10 Jul 2025 04:34:09 +0000 Subject: [PATCH 12/30] flahinfer cutlass moe support TP chunking --- .../layers/fused_moe/cutlass_moe.py | 55 ------------------- .../fused_moe/flashinfer_cutlass_moe.py | 12 ++-- .../flashinfer_cutlass_prepare_finalize.py | 7 ++- .../layers/fused_moe/modular_kernel.py | 4 +- .../layers/quantization/modelopt.py | 8 ++- 5 files changed, 17 insertions(+), 69 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index d3488c435748..fdbaa259dbb6 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -417,21 +417,6 @@ def run_cutlass_moe_fp4( e: int, device: torch.device, ) -> None: - # print("---------------start-------------") - # print(f"output.shape: {output.shape}") # [m, n] - # print(f"a.shape: {a.shape}") # [m, k] - # print(f"a1_gscale.shape: {a1_gscale.shape}") # [1] - # print(f"w1_fp4.shape: {w1_fp4.shape}") # [e, k//2, n] - # print(f"w1_blockscale.shape: {w1_blockscale.shape}") # [e, n//groupsize] - # print(f"w1_alphas.shape: {w1_alphas.shape}") # [e, 1] - # print(f"a2_gscale.shape: {a2_gscale.shape}") # [1] - # print(f"w2_fp4.shape: {w2_fp4.shape}") # [e, n//2, k] - # print(f"w2_blockscale.shape: {w2_blockscale.shape}") # [e, k//groupsize] - # print(f"w2_alphas.shape: {w2_alphas.shape}") # [e, 1] - # print(f"topk_weights.shape: {topk_weights.shape}") # [m, top_k] - # print(f"topk_ids.shape: {topk_ids.shape}") # [m, top_k] - # 
print(f"m:{m}, n:{n}, k:{k}, e:{e}") - """ MoE implementation for FP4 Inputs @@ -475,8 +460,6 @@ def run_cutlass_moe_fp4( and k == k_w2), ("Hidden size mismatch between a, w1 and w2") assert (nx2_w1 == n * 2 and half_n_w2 * 2 == n), ("mismatch in " "expected `n`") - # print(f"m:{m}; m_a:{m_a}") - # m = m_a assert (m == m_a), "input shape mismatch" assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1" assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype" @@ -503,7 +486,6 @@ def run_cutlass_moe_fp4( blockscale_offsets) a = ops.shuffle_rows(a, a_map) - # print(f"a:{a.shape}; a1_gscale:{a1_gscale.shape}") rep_a_fp4, rep_a_blockscale = ops.scaled_fp4_experts_quant( a, a1_gscale, @@ -518,38 +500,21 @@ def run_cutlass_moe_fp4( w1_blockscale, w1_alphas, problem_sizes1, expert_offsets[:-1], blockscale_offsets[:-1], out_dtype, device) - # print(f"c1:{c1.shape}") del rep_a_fp4, rep_a_blockscale - # hidden size dimension is split to one halfpytho sized tensor. - # intermediate = torch.empty((m * num_topk, w1_fp4.size(1) // 2), - # device=device, - # dtype=out_dtype) - # print(f"intermediate:{intermediate.shape}") torch.ops._C.silu_and_mul(c2, c1) - # torch.ops._C.silu_and_mul(intermediate, c1) - # print(f"c2:{c2.shape}") - # print(f"a2_gscale:{a2_gscale.shape}") - # import pdb - # pdb.set_trace() int_fp4, int_blockscale = ops.scaled_fp4_experts_quant( c2, a2_gscale, expert_offsets, blockscale_offsets, num_topk) ops.cutlass_fp4_moe_mm(c3, int_fp4, w2_fp4, int_blockscale, w2_blockscale, w2_alphas, problem_sizes2, expert_offsets[:-1], blockscale_offsets[:-1], out_dtype, device) - # print(f"c2:{c2.shape}") del int_fp4, int_blockscale c3 = ops.shuffle_rows(c3, c_map) assert output.dtype == out_dtype - # print(f"output:{output.shape}") - # print(f"c3:{c3.view(m, num_topk, k).shape}") - # print(f"b:{topk_weights.view(m, num_topk, 1).half().shape}") - # print(f"g:{g.shape}") output.copy_((c3.view(m, num_topk, k) * topk_weights.view(m, num_topk, 1).half()).sum(dim=1), non_blocking=True) - # print("---------------end-------------") return @@ -591,26 +556,6 @@ def supports_expert_map(self) -> bool: def supports_chunking(self) -> bool: return True - - # def workspace_shapes( - # self, - # a: torch.Tensor, - # aq: torch.Tensor, - # M: int, - # N: int, - # K: int, - # topk: int, - # global_num_experts: int, - # local_num_experts: int, - # ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: - # # Workspace1: for c1, Workspace2: for intermediate, Output: for final output - # workspace1 = a.shape - # # workspace 2 remains empty because run_cutlass_moe_fp4 allocates the - # # intermediate tensor there. - # # # workspace 1 is allocated to store the output. 
- # workspace2 = () - # output = a.shape - # return (workspace1, workspace2, output, self.out_dtype) def workspace_shapes( self, diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 896d4201bf58..88e602557c21 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -85,7 +85,8 @@ def supports_expert_map(self) -> bool: def supports_chunking(self) -> bool: #TODO(shuw): support chunking later, actually support in layer.py - return False + # It means TP chunking + return True def workspace_shapes( self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int, @@ -136,14 +137,15 @@ def apply( w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], # Not used workspace13:Optional[torch.Tensor], workspace2:Optional[torch.Tensor], expert_num_tokens: Optional[torch.Tensor], topk_weights: torch.Tensor, g1_alphas: torch.Tensor, g2_alphas: torch.Tensor, - a1_scale: torch.Tensor, + a1_gscale: torch.Tensor, + a2_gscale: torch.Tensor, out_dtype: torch.dtype, ): # Flashinfer CUTLASS kernel takes scalar global scales, @@ -151,10 +153,10 @@ def apply( assert self.use_nvfp4_w4a4 is True, ("Only nvfp4 quantization is " "currently supported.") quant_scales = [ - a1_scale, + a1_gscale, w1_scale.view(torch.int32), g1_alphas, - a2_scale, + a2_gscale, w2_scale.view(torch.int32), g2_alphas, ] diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py index 7ac4d08eab39..c72d8153e689 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -54,14 +54,15 @@ def topk_indices_dtype(self) -> Optional[torch.dtype]: def prepare( self, a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], + a1_scale: Optional[torch.Tensor], # Not used + a2_scale: Optional[torch.Tensor], # Not used topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, + a1_gscale: torch.Tensor, use_dp: Optional[bool] = True, local_tokens: int = -1, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], @@ -76,7 +77,7 @@ def prepare( a1q, a1q_scale = moe_kernel_quantize_input( a1, - a1_scale, + a1_gscale, quant_config.quant_dtype, self.per_channel_quant, self.block_shape, diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 4d7605a6837e..581d462da813 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -461,8 +461,7 @@ def forward( global_num_experts = local_num_experts prepare_kwargs = extra_prepare_args or {} - # import pdb - # pdb.set_trace() + (a1q, a1q_scale, expert_num_tokens, _expert_topk_ids, _expert_topk_weights) = self.prepare_finalize.prepare( a1, @@ -560,7 +559,6 @@ def forward( else: # The leading output dimension may not be equal to M, so # we compute output indices separately. 
- # print('ttt'*100) M_out = fused_out_shape[0] assert M_out >= M factor = M_out // M diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 964250aa9dda..233b8b76795c 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -830,6 +830,8 @@ def apply( scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias) + a1_gscale = torch.min(layer.w13_input_scale_quant) + a2_gscale = torch.min(layer.w2_input_scale_quant) if self.allow_flashinfer_cutlass: # TP or DP case assert _valid_flashinfer_fused_moe( @@ -841,11 +843,13 @@ def apply( 'out_dtype': x.dtype, # Avoid confusion with a1_scale and a2_scale whare are batch size # related. - 'a1_scale': torch.min(layer.w13_input_scale_quant), + 'a1_gscale': a1_gscale, + 'a2_gscale': a2_gscale, } extra_prepare_args = { 'use_dp': layer.dp_size > 1, 'local_tokens': x.shape[0], + 'a1_gscale': a1_gscale, } extra_finalize_args = { 'use_dp': layer.dp_size > 1, @@ -864,8 +868,6 @@ def apply( expert_map=expert_map, w1_scale=layer.w13_blockscale_swizzled, w2_scale=layer.w2_blockscale_swizzled, - a1_scale=torch.min(layer.w13_input_scale_quant), - a2_scale=torch.min(layer.w2_input_scale_quant), extra_expert_args=extra_expert_args, extra_prepare_args=extra_prepare_args, extra_finalize_args=extra_finalize_args, From 7b5e203f8b7f15761693030e603f0c6d5ae9d416 Mon Sep 17 00:00:00 2001 From: shuw Date: Thu, 10 Jul 2025 05:02:14 +0000 Subject: [PATCH 13/30] Address comment and clean Up --- vllm/distributed/device_communicators/all2all.py | 2 -- vllm/model_executor/layers/fused_moe/cutlass_moe.py | 5 +++++ .../layers/fused_moe/flashinfer_cutlass_moe.py | 3 +-- vllm/model_executor/layers/fused_moe/layer.py | 5 ++--- .../layers/fused_moe/modular_kernel.py | 12 ------------ .../layers/fused_moe/prepare_finalize.py | 5 +---- vllm/model_executor/layers/quantization/modelopt.py | 1 - 7 files changed, 9 insertions(+), 24 deletions(-) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index b7beb70173dd..c9e1e1675c45 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -65,8 +65,6 @@ def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ self.dp_rank - 1] end = cu_tokens_across_dp_cpu[self.dp_rank] - print(f"dp_ws:{self.dp_group.world_size}; ws:{self.world_size}") - print(f"dp_rank:{self.dp_rank}; rank: {self.rank}") all_hidden_states = self.dp_group.all_reduce(hidden_states) hidden_states = all_hidden_states[start:end, :] return hidden_states diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index fdbaa259dbb6..39c5c87994c8 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -684,9 +684,14 @@ def cutlass_moe_fp4( 'e': e, 'device': device, } + + # NVFP4 requires two levels of quantization, which involves computing some scaling + # factors dynamically. This makes it incompatible with the typical + # prepare -> MoE -> finalize pipeline. Move the quantization logic into the MoE body. extra_prepare_args = { 'skip_quant': True, } + # Similar reason as above. 
extra_finalize_args = { 'skip_permute_reduce': True, } diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 88e602557c21..f6b06be9d868 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -84,8 +84,7 @@ def supports_expert_map(self) -> bool: return False def supports_chunking(self) -> bool: - #TODO(shuw): support chunking later, actually support in layer.py - # It means TP chunking + # This refers to TP chunking; DP chunking is handled separately. return True def workspace_shapes( diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 373ba48e6182..a662ca6c512c 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -933,7 +933,8 @@ def _load_w13(self, expert_data: torch.Tensor, shard_dim: int, # Narrow parameter and load. # w1, gate_proj: Load into first logical weight of w13. # w3, up_proj: Load into second logical weight of w13. - # trtllm cutlass kernel assumes differently + # The FlashInfer Cutlass fused MoE kernel expects the combined weights + # to be ordered as [w3, w1], unlike the standard [w1, w3] layout. assert shard_id in ("w1", "w3") switch_w13 = getattr(self.quant_method, 'load_up_proj_weight_first', False) @@ -1422,7 +1423,6 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): final_hidden_states, non_blocking=True) ctx = get_forward_context() - #TODO(shuw):where is it? # flashinfer_cutlass_kernels can handle: optional DP + TP/EP max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens @@ -1447,7 +1447,6 @@ def forward_impl(self, hidden_states: torch.Tensor, assert self.quant_method is not None # Route to the chunked forward path using the FlashInfer Cutlass kernel # only when data parallelism (DP) is enabled. - # TODO(shuw): Make TP calling also chunked. 
use_flashinfer_cutlass_kernels = self.dp_size > 1 and self.moe_parallel_config.use_flashinfer_cutlass_kernels if (self.moe_parallel_config.use_pplx_kernels or self.moe_parallel_config.use_deepep_ll_kernels or diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 581d462da813..78c8087816a9 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -497,11 +497,9 @@ def forward( if self.fused_experts.enable_chunking(): CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE num_chunks = cdiv(M, CHUNK_SIZE) - # print(f"1CHUNK_SIZE:{CHUNK_SIZE} and M:{M}") else: CHUNK_SIZE = M num_chunks = 1 - # print(f"2CHUNK_SIZE:{CHUNK_SIZE} and M:{M}") if num_chunks == 1: (workspace13_shape, workspace2_shape, fused_out_shape, @@ -531,7 +529,6 @@ def forward( expert_kwargs = extra_expert_args or {} if num_chunks == 1: - # print('gggg'*100) fused_out = _resize_cache(workspace13, fused_out_shape) if 'topk_weights' in expert_kwargs and expert_kwargs['topk_weights'] is None: expert_kwargs['topk_weights'] = topk_weights @@ -591,15 +588,6 @@ def forward( if 'm' in expert_kwargs and expert_kwargs['m'] is not None: expert_kwargs['m'] = end_chunk_idx - begin_chunk_idx - - # if 'a1_scale' in expert_kwargs and expert_kwargs['a1_scale'] is not None: - # print("before swapping a1_scale"*10) - # print(f"a1_scale in kwargs:{expert_kwargs['a1_scale'].shape}") - # expert_kwargs['a1_scale'] = expert_kwargs['a1_scale'][begin_chunk_idx:end_chunk_idx] - # print(f"a1_scale in kwargs:{expert_kwargs['a1_scale'].shape}") - # print("after swapping a1_scale"*10) - # assert expert_kwargs['a1_scale'] is not None - self.fused_experts.apply( fused_out[begin_out_idx:end_out_idx], curr_a1q, diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index ec67e6984df5..f1f805b755fc 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -47,9 +47,8 @@ def prepare( a1.mul_(topk_weights.to(a1.dtype)) if skip_quant: - # print("skiped_quant"*10) - # print(f"skiped_quant:{skip_quant}") return a1, None, None, None, None + a1q, a1q_scale = moe_kernel_quantize_input( a1, a1_scale, quant_config.quant_dtype, quant_config.per_act_token_quant, quant_config.block_shape) @@ -66,8 +65,6 @@ def finalize( apply_router_weight_on_input: bool, skip_permute_reduce: Optional[bool]=False, ) -> None: - # print("skip_permute_reduce"*10) - # print(f"skip_permute_reduce:{skip_permute_reduce}") if skip_permute_reduce: assert output.shape == fused_expert_output.shape output.copy_(fused_expert_output) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 233b8b76795c..eb6a21a7f53e 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -563,7 +563,6 @@ def select_gemm_impl(self, prepare_finalize, moe): from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( FlashInferExperts) logger.debug("FlashInferExperts %s", moe) - # assert moe.dp_size == all2all_manager.dp_world_size experts = FlashInferExperts( use_nvfp4_w4a4=True, use_dp=moe.moe_parallel_config.dp_size>1, From 131f1414d79bd0a6089d813e5f3a653c484dc3a9 Mon Sep 17 00:00:00 2001 From: shuw Date: Fri, 11 Jul 2025 20:07:56 +0000 Subject: [PATCH 14/30] Upd --- 
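A minimal sketch of the weight-and-reduce delegation this commit wires into the CUTLASS FP4 and FlashInfer experts (the real marker class is imported from fused_moe/topk_weight_and_reduce.py; the names and signatures below are simplified for illustration): the experts return a delegate marker from finalize_weight_and_reduce_impl(), and PrepareAndFinalize::finalize() picks the concrete implementation.

import torch

# Simplified illustration -- not the vLLM classes.
class TopKWeightAndReduceDelegate:
    """Marker object: let PrepareAndFinalize.finalize() choose the impl."""

class WeightedSumReduce:
    def apply(self, fused_out: torch.Tensor,
              topk_weights: torch.Tensor) -> torch.Tensor:
        # fused_out: (M, topk, K), topk_weights: (M, topk)
        return (fused_out * topk_weights.unsqueeze(-1)).sum(dim=1)

def finalize(fused_out: torch.Tensor, topk_weights: torch.Tensor,
             impl) -> torch.Tensor:
    if isinstance(impl, TopKWeightAndReduceDelegate):
        impl = WeightedSumReduce()  # the finalize side decides the real impl
    return impl.apply(fused_out, topk_weights)

out = finalize(torch.randn(4, 2, 8), torch.rand(4, 2),
               TopKWeightAndReduceDelegate())
assert out.shape == (4, 8)
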
.../layers/fused_moe/cutlass_moe.py | 162 +++++++++++++++++- .../fused_moe/flashinfer_cutlass_moe.py | 13 +- .../flashinfer_cutlass_prepare_finalize.py | 6 + .../layers/fused_moe/modular_kernel.py | 34 ++-- .../layers/fused_moe/prepare_finalize.py | 5 +- 5 files changed, 202 insertions(+), 18 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 2a9106226225..82de0b505ba2 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -443,6 +443,7 @@ def run_cutlass_moe_fp4( k: int, e: int, device: torch.device, + apply_router_weight_on_input: bool = False, ) -> None: """ MoE implementation for FP4 Inputs @@ -589,6 +590,10 @@ def supports_expert_map(self) -> bool: def supports_chunking(self) -> bool: return True + + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + # Let PrepareAndFinalize::finalize() decide the impl. + return TopKWeightAndReduceDelegate() def workspace_shapes( self, @@ -634,7 +639,7 @@ def apply( a2_scale: torch.Tensor, workspace13: Optional[torch.Tensor], workspace2: Optional[torch.Tensor], - expert_num_tokens: Optional[torch.Tensor], + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], # start of extra_expert_args topk_weights: torch.Tensor, g1_alphas: torch.Tensor, @@ -649,7 +654,8 @@ def apply( ): assert expert_map is None, ("Expert Parallelism / expert_map " "is currently not supported for " - "ModelOptNvFp4FusedMoE.") + "ModelOptNvFp4FusedMoE.") + run_cutlass_moe_fp4( output, hidden_states, @@ -726,7 +732,7 @@ def cutlass_moe_fp4( } # Similar reason as above. extra_finalize_args = { - 'skip_permute_reduce': True, + 'skip_weight_reduce': True, } return fn( hidden_states=a, @@ -747,3 +753,153 @@ def cutlass_moe_fp4( extra_prepare_args=extra_prepare_args, extra_finalize_args=extra_finalize_args, ) + + +def _valid_cutlass_block_scaled_grouped_gemm( + w1: torch.Tensor, w2: torch.Tensor, inplace: bool, activation: str, + apply_router_weight_on_input: bool, + expert_map: Optional[torch.Tensor]) -> bool: + + def _valid_cutlass_block_scaled_grouped_gemm_shape(N: int, K: int): + return N % 128 == 0 and K % 128 == 0 + + _, K, N = w2.size() + if not _valid_cutlass_block_scaled_grouped_gemm_shape(N, K): + logger.debug( + "CutlassBlockScaledGroupedGemm disabled: unalinged problem size.") + return False + + if (w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn): + logger.debug( + "CutlassBlockScaledGroupedGemm disabled: invalid weight dtype(s).") + return False + + if expert_map is not None: + logger.debug( + "CutlassBlockScaledGroupedGemm disabled: expert_parallel is" + " not supported.") + return False + + if activation != "silu": + logger.debug( + "CutlassBlockScaledGroupedGemm disabled: only activation silu is" + " supported.") + return False + + if apply_router_weight_on_input: + logger.debug("CutlassBlockScaledGroupedGemm disabled:" + " apply_router_weight_on_input is not supported.") + return False + + if inplace: + logger.debug( + "CutlassBlockScaledGroupedGemm disabled: inplace is not supported." 
+ ) + return False + + return True + + +def run_cutlass_block_scaled_fused_experts( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, +) -> torch.Tensor: + w1_q = w1.transpose(1, 2) + w2_q = w2.transpose(1, 2) + w1_scale = w1_scale.transpose(1, 2) + w2_scale = w2_scale.transpose(1, 2) + + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert a.shape[0] == topk_ids.shape[ + 0], "a and topk_ids must have the same batch size" + assert w1_q.dtype == torch.float8_e4m3fn, "w1_q must be float8_e4m3fn" + assert w2_q.dtype == torch.float8_e4m3fn, "w2_q must be float8_e4m3fn" + assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1" + assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2" + assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch" + assert w1_q.shape[0] == w1_scale.shape[ + 0], "w1_scale expert number mismatch" + assert w1_q.shape[0] == w2_scale.shape[ + 0], "w2_scale expert number mismatch" + assert a.dtype in [torch.half, torch.bfloat16], "Invalid output dtype" + + out_dtype = a.dtype + num_experts = w1_q.size(0) + m = a.size(0) + k = w1_q.size(1) + n = w2_q.size(1) + + expert_offsets = torch.empty((num_experts + 1, ), + dtype=torch.int32, + device="cuda") + problem_sizes1 = torch.empty((num_experts, 3), + dtype=torch.int32, + device="cuda") + problem_sizes2 = torch.empty((num_experts, 3), + dtype=torch.int32, + device="cuda") + + topk = topk_ids.size(1) + + a_q, a1_scale = _fp8_quantize(a, + A_scale=None, + per_act_token=False, + block_shape=[128, 128]) + device = a_q.device + + a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) + c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) + + ops.get_cutlass_moe_mm_data( + topk_ids, + expert_offsets, + problem_sizes1, + problem_sizes2, + a_map, + c_map, + num_experts, + n, + k, + ) + + rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype) + rep_a1_scales = a1_scale[a_map] + + c1 = torch.empty((m * topk, n * 2), dtype=out_dtype, device=device) + c2 = torch.empty((m * topk, k), dtype=out_dtype, device=device) + + ops.cutlass_blockwise_scaled_grouped_mm( + c1, + rep_a_q, + w1_q, + rep_a1_scales, + w1_scale, + problem_sizes1, + expert_offsets[:-1], + ) + + intermediate = torch.empty((m * topk, n), dtype=out_dtype, device=device) + torch.ops._C.silu_and_mul(intermediate, c1) + + intermediate_q, a2_scale = _fp8_quantize(intermediate, + A_scale=None, + per_act_token=False, + block_shape=[128, 128]) + + ops.cutlass_blockwise_scaled_grouped_mm( + c2, + intermediate_q, + w2_q, + a2_scale, + w2_scale, + problem_sizes2, + expert_offsets[:-1], + ) + + return (c2[c_map].view(m, topk, k) * + topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index f6b06be9d868..673bf0c93c1b 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -9,6 +9,8 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( FlashInferCutlassMoEPrepareAndFinalize) from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceDelegate) from vllm.utils import round_up @@ -58,6 +60,8 @@ def 
__init__(self, ep_size: int=1, tp_rank: int=0, tp_size: int=1, + num_dispatchers: Optional[int] = None, + use_batched_format: bool = False, ): super().__init__( FusedMoEQuantConfig( @@ -72,6 +76,9 @@ def __init__(self, self.tp_rank=tp_rank self.tp_size=tp_size self.use_dp=use_dp + assert not use_batched_format or num_dispatchers is not None + self.num_dispatchers = num_dispatchers + @property def activation_formats( @@ -87,6 +94,10 @@ def supports_chunking(self) -> bool: # This refers to TP chunking; DP chunking is handled separately. return True + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + # Let PrepareAndFinalize::finalize() decide the impl. + return TopKWeightAndReduceDelegate() + def workspace_shapes( self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int, topk: int, global_num_experts: int, local_num_experts: int @@ -139,7 +150,7 @@ def apply( a2_scale: Optional[torch.Tensor], # Not used workspace13:Optional[torch.Tensor], workspace2:Optional[torch.Tensor], - expert_num_tokens: Optional[torch.Tensor], + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], topk_weights: torch.Tensor, g1_alphas: torch.Tensor, g2_alphas: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py index c72d8153e689..7e5004690389 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -35,11 +35,13 @@ def __init__( quant_dtype: Optional[torch.dtype] = None, per_channel_quant: bool = False, block_shape: Optional[list[int]] = None, + num_dispatchers: int = 1, ): super().__init__() self.per_channel_quant = per_channel_quant self.block_shape = block_shape self.quant_dtype = quant_dtype + self.num_dispatchers_ = num_dispatchers @property def activation_format(self) -> mk.FusedMoEActivationFormat: @@ -50,6 +52,9 @@ def max_num_tokens_per_rank(self) -> Optional[int]: def topk_indices_dtype(self) -> Optional[torch.dtype]: return None + + def num_dispatchers(self) -> int: + return self.num_dispatchers_ def prepare( self, @@ -102,6 +107,7 @@ def finalize( topk_weights: torch.Tensor, topk_ids: torch.Tensor, apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, use_dp: bool = False, local_tokens: int = -1, ) -> None: diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 0d183949830f..c0e74d7aa3a1 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -461,7 +461,8 @@ def _do_fused_experts( w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], - expert_tokens_meta: Optional[ExpertTokensMetadata] + expert_tokens_meta: Optional[ExpertTokensMetadata], + extra_expert_kwargs: Optional[dict] = None ) -> torch.Tensor: _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids) @@ -501,7 +502,8 @@ def _do_fused_experts( a2_scale=a2_scale, workspace13=workspace13, workspace2=workspace2, - expert_tokens_meta=expert_tokens_meta) + expert_tokens_meta=expert_tokens_meta, + **extra_expert_kwargs) return fused_out @@ -523,10 +525,6 @@ def _maybe_chunk_fused_experts( CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE num_chunks = cdiv(M, CHUNK_SIZE) - if 'topk_weights' in extra_expert_kwargs and 
extra_expert_kwargs['topk_weights'] is None: - extra_expert_kwargs['topk_weights'] = topk_weights - assert extra_expert_kwargs['topk_weights'] is not None - if not self.fused_experts.supports_chunking() or num_chunks == 1: return self._do_fused_experts( fused_out=None, @@ -546,7 +544,7 @@ def _maybe_chunk_fused_experts( a1q_scale=a1q_scale, a2_scale=a2_scale, expert_tokens_meta=expert_tokens_meta, - **extra_expert_kwargs) + extra_expert_kwargs=extra_expert_kwargs) # Chunking required case assert num_chunks > 1 @@ -599,6 +597,10 @@ def slice_expert_tokens_metadata( expert_num_tokens=c_expert_num_tokens, expert_num_tokens_cpu=c_expert_num_tokens_cpu) + topk_weights = extra_expert_kwargs.get('topk_weights') + m = extra_expert_kwargs.get('m') + + chunked_extra_expert_kwargs = extra_expert_kwargs for chunk_idx in range(num_chunks): c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids = ( slice_input_tensors(chunk_idx)) @@ -609,10 +611,13 @@ def slice_expert_tokens_metadata( expert_tokens_meta, c_topk_ids, local_num_experts, expert_map) - if 'm' in extra_expert_kwargs and extra_expert_kwargs['m'] is not None: - s = chunk_idx * CHUNK_SIZE - e = min(s + CHUNK_SIZE, M) - extra_expert_kwargs['m'] = e - s + s = chunk_idx * CHUNK_SIZE + e = min(s + CHUNK_SIZE, M) + + if m is not None: + chunked_extra_expert_kwargs['m'] = e - s + if topk_weights is not None: + chunked_extra_expert_kwargs['topk_weights'] = topk_weights[s:e] self._do_fused_experts(fused_out=slice_output_tensor(chunk_idx), a1=a1, @@ -631,7 +636,7 @@ def slice_expert_tokens_metadata( a1q_scale=c_a1q_scale, a2_scale=c_a2_scale, expert_tokens_meta=c_expert_tokens_meta, - **extra_expert_kwargs) + extra_expert_kwargs=chunked_extra_expert_kwargs) return fused_out @@ -726,6 +731,11 @@ def forward( fused_out = None extra_expert_kwargs = extra_expert_args or {} + + if 'topk_weights' in extra_expert_kwargs and extra_expert_kwargs['topk_weights'] is None: + extra_expert_kwargs['topk_weights'] = topk_weights + assert extra_expert_kwargs['topk_weights'] is not None + if a1q.numel() == 0: # This happens when none of the tokens from the all2all reach this # EP rank. 
Also, note that this is only relevant for CUDAGraph diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index 76ac6e801a5a..a73b14e05949 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -38,6 +38,7 @@ def prepare( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, + skip_quant: Optional[bool]=False, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], Optional[torch.Tensor]]: @@ -66,9 +67,9 @@ def finalize( topk_ids: torch.Tensor, apply_router_weight_on_input: bool, weight_and_reduce_impl: mk.TopKWeightAndReduce, - skip_permute_reduce: Optional[bool]=False, + skip_weight_reduce: Optional[bool]=False, ) -> None: - if skip_permute_reduce: + if skip_weight_reduce: assert output.shape == fused_expert_output.shape output.copy_(fused_expert_output) else: From 3bdbeb1e76e7604ce4398471b8ee12533f95b57b Mon Sep 17 00:00:00 2001 From: shuw Date: Sat, 12 Jul 2025 04:13:02 +0000 Subject: [PATCH 15/30] Fix lint --- vllm/_custom_ops.py | 13 +-- vllm/distributed/parallel_state.py | 2 +- .../model_executor/layers/fused_moe/config.py | 13 +-- .../layers/fused_moe/cutlass_moe.py | 25 ++--- .../fused_moe/flashinfer_cutlass_moe.py | 49 +++++----- .../flashinfer_cutlass_prepare_finalize.py | 19 ++-- vllm/model_executor/layers/fused_moe/layer.py | 22 ++--- .../layers/fused_moe/modular_kernel.py | 98 +++++++++++-------- .../layers/fused_moe/prepare_finalize.py | 6 +- vllm/model_executor/layers/fused_moe/utils.py | 4 +- .../layers/quantization/modelopt.py | 62 ++++++------ 11 files changed, 163 insertions(+), 150 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index f85cc115efe5..1117e825198a 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -958,9 +958,9 @@ def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, def cutlass_fp4_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, - b_tensors: torch.Tensor, - a_scales: torch.Tensor, b_scales: torch.Tensor, - alphas: torch.Tensor, problem_sizes: torch.Tensor, + b_tensors: torch.Tensor, a_scales: torch.Tensor, + b_scales: torch.Tensor, alphas: torch.Tensor, + problem_sizes: torch.Tensor, expert_offsets: torch.Tensor, sf_offsets: torch.Tensor, out_dtype: torch.dtype, device: torch.device): """ @@ -979,9 +979,10 @@ def cutlass_fp4_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, - problem_sizes: MxNxK sizes of each expert's multiplication in two grouped MMs used in the fused MoE operation. 
""" - return torch.ops._C.cutlass_fp4_group_mm(out_tensors, a_tensors, b_tensors, a_scales, - b_scales, alphas, problem_sizes, - expert_offsets, sf_offsets) + return torch.ops._C.cutlass_fp4_group_mm(out_tensors, a_tensors, b_tensors, + a_scales, b_scales, alphas, + problem_sizes, expert_offsets, + sf_offsets) # aqlm diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index d1a525c140f7..ed6a40cb375e 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -370,7 +370,7 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: return input_ assert -input_.dim() <= dim < input_.dim(), ( f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") - + # TODO(shuw): enable it if self.use_custom_op_call and False: return torch.ops.vllm.all_gather(input_, diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 9e6a3cdb0657..d5d4eba950a3 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union import torch from compressed_tensors.quantization import (QuantizationArgs, @@ -16,8 +16,6 @@ QuantizationConfig) from vllm.utils import cdiv -from typing import TYPE_CHECKING - try: from flashinfer import fp4_quantize as fp4_quantize from flashinfer.fused_moe import ( @@ -188,6 +186,7 @@ def use_all2all_kernels(self): def use_pplx_kernels(self): return (self.use_all2all_kernels and envs.VLLM_ALL2ALL_BACKEND == "pplx") + @property def use_deepep_ht_kernels(self): return (self.use_all2all_kernels @@ -446,9 +445,11 @@ def make( from vllm.model_executor.layers.quantization.fp8 import Fp8Config if quant_dtype is None and isinstance(quant_config, Fp8Config): quant_dtype = torch.float8_e4m3fn - - from vllm.model_executor.layers.quantization.modelopt import ModelOptNvFp4Config - if quant_dtype is None and isinstance(quant_config, ModelOptNvFp4Config): + + from vllm.model_executor.layers.quantization.modelopt import ( + ModelOptNvFp4Config) + if quant_dtype is None and isinstance(quant_config, + ModelOptNvFp4Config): quant_dtype = torch.uint8 if weight_quant is not None: diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 82de0b505ba2..d872b552f656 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -487,7 +487,7 @@ def run_cutlass_moe_fp4( assert (k_a == half_k_w1 * 2 and k == k_w2), ("Hidden size mismatch between a, w1 and w2") assert (nx2_w1 == n * 2 and half_n_w2 * 2 == n), ("mismatch in " - "expected `n`") + "expected `n`") assert (m == m_a), "input shape mismatch" assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1" assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype" @@ -531,28 +531,30 @@ def run_cutlass_moe_fp4( c2 = _resize_cache(workspace2, (m * topk, n)) c3 = _resize_cache(workspace13, (m * topk, k)) ops.cutlass_fp4_moe_mm(c1, rep_a_fp4, w1_fp4, rep_a_blockscale, - w1_blockscale, w1_alphas, problem_sizes1, - expert_offsets[:-1], blockscale_offsets[:-1], - out_dtype, device) + w1_blockscale, w1_alphas, problem_sizes1, + expert_offsets[:-1], blockscale_offsets[:-1], + out_dtype, device) 
del rep_a_fp4, rep_a_blockscale torch.ops._C.silu_and_mul(c2, c1) int_fp4, int_blockscale = ops.scaled_fp4_experts_quant( c2, a2_gscale, expert_offsets, blockscale_offsets, num_topk) ops.cutlass_fp4_moe_mm(c3, int_fp4, w2_fp4, int_blockscale, w2_blockscale, - w2_alphas, problem_sizes2, expert_offsets[:-1], - blockscale_offsets[:-1], out_dtype, device) + w2_alphas, problem_sizes2, expert_offsets[:-1], + blockscale_offsets[:-1], out_dtype, device) del int_fp4, int_blockscale c3 = ops.shuffle_rows(c3, c_map) assert output.dtype == out_dtype output.copy_((c3.view(m, num_topk, k) * - topk_weights.view(m, num_topk, 1).half()).sum(dim=1), non_blocking=True) + topk_weights.view(m, num_topk, 1).half()).sum(dim=1), + non_blocking=True) return class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): + def __init__( self, max_experts_per_worker: int, @@ -568,8 +570,7 @@ def __init__( per_act_token_quant=per_act_token_quant, per_out_ch_quant=per_out_ch_quant, block_shape=block_shape, - ) - ) + )) self.max_experts_per_worker = max_experts_per_worker self.out_dtype = out_dtype self.use_batched_format = use_batched_format @@ -594,7 +595,7 @@ def supports_chunking(self) -> bool: def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: # Let PrepareAndFinalize::finalize() decide the impl. return TopKWeightAndReduceDelegate() - + def workspace_shapes( self, a: torch.Tensor, @@ -708,7 +709,7 @@ def cutlass_moe_fp4( out_dtype=a.dtype, per_act_token_quant=False, per_out_ch_quant=False, - use_batched_format=False, + use_batched_format=False, ), ) extra_expert_args = { @@ -725,7 +726,7 @@ def cutlass_moe_fp4( } # NVFP4 requires two levels of quantization, which involves computing some scaling - # factors dynamically. This makes it incompatible with the typical + # factors dynamically. This makes it incompatible with the typical # prepare -> MoE -> finalize pipeline. Move the quantization logic into the MoE body. 
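# A minimal sketch, assuming simplified stand-ins for the real
# prepare/experts/finalize interfaces, of the control flow that the
# 'skip_quant' / 'skip_weight_reduce' flags enable: activations pass through
# prepare() unquantized, both levels of NVFP4 scaling happen inside the
# expert body, and finalize() reduces to a plain copy. The helper names
# below are invented for illustration only and are not part of the patch.
import torch

def sketch_prepare(a: torch.Tensor, skip_quant: bool):
    # skip_quant=True: hand the raw activations straight to the expert body.
    return (a, None) if skip_quant else (a.half(), a.abs().amax())

def sketch_experts(a: torch.Tensor, a1_gscale: torch.Tensor) -> torch.Tensor:
    # Placeholder for the in-body two-level quantization + grouped GEMMs;
    # the global scale and the dynamic block scales are both applied here.
    return (a * a1_gscale) / a1_gscale

def sketch_finalize(output: torch.Tensor, fused: torch.Tensor,
                    skip_weight_reduce: bool) -> None:
    if skip_weight_reduce:
        # Top-k weighting already happened inside the expert body.
        output.copy_(fused)

_a = torch.randn(4, 8)
_x, _ = sketch_prepare(_a, skip_quant=True)
_fused = sketch_experts(_x, torch.tensor(0.5))
_out = torch.empty_like(_fused)
sketch_finalize(_out, _fused, skip_weight_reduce=True)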
extra_prepare_args = { 'skip_quant': True, diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 673bf0c93c1b..355db1b72554 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -1,26 +1,23 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Dict +from typing import Optional import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( - FlashInferCutlassMoEPrepareAndFinalize) from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate) -from vllm.utils import round_up - logger = init_logger(__name__) from typing import TYPE_CHECKING try: from flashinfer import fp4_quantize as fp4_quantize - from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe + from flashinfer.fused_moe import ( + cutlass_fused_moe as flashinfer_cutlass_fused_moe) except ImportError: if not TYPE_CHECKING: cutlass_fused_moe = None @@ -50,16 +47,18 @@ def _valid_flashinfer_fused_moe(hidden_states: torch.Tensor, w1: torch.Tensor, return False return True + class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): - def __init__(self, + def __init__( + self, use_nvfp4_w4a4: bool = False, use_fp8_w8a8: bool = False, - use_dp: bool=False, - ep_rank: int=0, - ep_size: int=1, - tp_rank: int=0, - tp_size: int=1, + use_dp: bool = False, + ep_rank: int = 0, + ep_size: int = 1, + tp_rank: int = 0, + tp_size: int = 1, num_dispatchers: Optional[int] = None, use_batched_format: bool = False, ): @@ -71,15 +70,14 @@ def __init__(self, )) self.use_nvfp4_w4a4 = use_nvfp4_w4a4 self.use_fp8_w8a8 = use_fp8_w8a8 - self.ep_rank=ep_rank - self.ep_size=ep_size - self.tp_rank=tp_rank - self.tp_size=tp_size - self.use_dp=use_dp + self.ep_rank = ep_rank + self.ep_size = ep_size + self.tp_rank = tp_rank + self.tp_size = tp_size + self.use_dp = use_dp assert not use_batched_format or num_dispatchers is not None self.num_dispatchers = num_dispatchers - @property def activation_formats( self @@ -89,7 +87,7 @@ def activation_formats( def supports_expert_map(self) -> bool: return False - + def supports_chunking(self) -> bool: # This refers to TP chunking; DP chunking is handled separately. return True @@ -97,7 +95,7 @@ def supports_chunking(self) -> bool: def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: # Let PrepareAndFinalize::finalize() decide the impl. return TopKWeightAndReduceDelegate() - + def workspace_shapes( self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int, topk: int, global_num_experts: int, local_num_experts: int @@ -121,17 +119,16 @@ def workspace_shapes( of each tuple must be the number of tokens. """ assert self.use_nvfp4_w4a4 is True, ("Only nvfp4 quantization is " - "currently supported.") + "currently supported.") aq_m, aq_n = aq.shape workspace2 = () output_shape = (aq_m, aq_n * 2) workspace_dtype = a.dtype workspace1 = output_shape - # The workspace is determined by `aq`, since it comes after any + # The workspace is determined by `aq`, since it comes after any # potential communication op and is involved in the expert computation. 
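# A concrete sizing example for the workspace shapes above, assuming (as
# noted elsewhere in this patch series) that nvfp4 activations pack two
# e2m1 values per uint8, so `aq` has half as many columns as `a`. The
# numbers are invented purely for illustration.
M, K = 32, 4096                       # tokens x hidden size of `a`
aq_m, aq_n = M, K // 2                # packed fp4 activation shape
workspace1 = output_shape = (aq_m, aq_n * 2)
assert output_shape == (M, K)         # unpacked output matches `a`'s hidden size
workspace2 = ()                       # no separate intermediate workspace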
return (workspace1, workspace2, output_shape, workspace_dtype) - def apply( self, output: torch.Tensor, @@ -148,8 +145,8 @@ def apply( w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], # Not used - workspace13:Optional[torch.Tensor], - workspace2:Optional[torch.Tensor], + workspace13: Optional[torch.Tensor], + workspace2: Optional[torch.Tensor], expert_tokens_meta: Optional[mk.ExpertTokensMetadata], topk_weights: torch.Tensor, g1_alphas: torch.Tensor, @@ -185,4 +182,4 @@ def apply( ep_size=self.ep_size, ep_rank=self.ep_rank, output=output, - ) \ No newline at end of file + ) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py index 7e5004690389..996825273fc6 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -1,18 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from typing import Optional -import vllm.envs as envs import torch -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.distributed import get_dp_group from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.utils import ( moe_kernel_quantize_input) -from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP) + def get_local_sizes(local_tokens): cu_sizes = get_forward_context().dp_metadata.cu_tokens_across_dp_cpu @@ -22,14 +21,16 @@ def get_local_sizes(local_tokens): max_num_tokens = envs.VLLM_MOE_DP_CHUNK_SIZE sizes_chunked = [max_num_tokens] * len(sizes) if local_tokens < max_num_tokens: - # When the number of local tokens is less than max_num_tokens, all other - # ranks will also have fewer than max_num_tokens. The remaining tokens + # When the number of local tokens is less than max_num_tokens, all other + # ranks will also have fewer than max_num_tokens. The remaining tokens # are accounted for as residual. 
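# Standalone restatement of the chunk-size logic above with concrete
# numbers; the per-rank token counts and the chunk budget are invented for
# illustration.
def sketch_local_sizes(sizes: list[int], local_tokens: int,
                       max_num_tokens: int) -> list[int]:
    sizes_chunked = [max_num_tokens] * len(sizes)
    if local_tokens < max_num_tokens:
        # Every rank reaches its residual (final, smaller) chunk together,
        # so the leftover token counts are the sizes modulo the chunk budget.
        sizes_chunked = [x % max_num_tokens for x in sizes]
    return sizes_chunked

# Two DP ranks with 10 tokens each and a chunk budget of 4:
assert sketch_local_sizes([10, 10], local_tokens=4, max_num_tokens=4) == [4, 4]
assert sketch_local_sizes([10, 10], local_tokens=2, max_num_tokens=4) == [2, 2]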
sizes_chunked = [x % max_num_tokens for x in sizes] return sizes_chunked + class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): + def __init__( self, quant_dtype: Optional[torch.dtype] = None, @@ -37,7 +38,7 @@ def __init__( block_shape: Optional[list[int]] = None, num_dispatchers: int = 1, ): - super().__init__() + super().__init__() self.per_channel_quant = per_channel_quant self.block_shape = block_shape self.quant_dtype = quant_dtype @@ -52,7 +53,7 @@ def max_num_tokens_per_rank(self) -> Optional[int]: def topk_indices_dtype(self) -> Optional[torch.dtype]: return None - + def num_dispatchers(self) -> int: return self.num_dispatchers_ @@ -67,7 +68,7 @@ def prepare( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - a1_gscale: torch.Tensor, + a1_gscale: torch.Tensor, use_dp: Optional[bool] = True, local_tokens: int = -1, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index d69f063f3230..4635d7715a88 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -4,7 +4,7 @@ from abc import abstractmethod from collections.abc import Iterable from enum import Enum -from typing import Callable, Literal, Optional, overload +from typing import TYPE_CHECKING, Callable, Literal, Optional, overload import torch import torch.nn.functional as F @@ -35,8 +35,6 @@ from vllm.platforms.interface import CpuArchEnum from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx -from typing import TYPE_CHECKING - try: from flashinfer import fp4_quantize as fp4_quantize from flashinfer.fused_moe import ( @@ -57,7 +55,8 @@ from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SHAPE, DeepEPLLPrepareAndFinalize) if has_flashinfer: - from .flashinfer_cutlass_prepare_finalize import FlashInferCutlassMoEPrepareAndFinalize + from .flashinfer_cutlass_prepare_finalize import ( + FlashInferCutlassMoEPrepareAndFinalize) else: fused_experts = None # type: ignore FusedMoEPermuteExpertsUnpermute = None # type: ignore @@ -107,8 +106,7 @@ def maybe_make_prepare_finalize( if moe.use_flashinfer_cutlass_kernels: prepare_finalize = FlashInferCutlassMoEPrepareAndFinalize( - quant_dtype=moe.quant_dtype, - ) + quant_dtype=moe.quant_dtype, ) if moe.use_pplx_kernels: hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes( moe.max_num_tokens, @@ -755,9 +753,9 @@ def __init__( quant_method: Optional[QuantizeMethodBase] = None quant_method = (UnquantizedFusedMoEMethod(moe) if quant_config is None else quant_config.get_quant_method(self, prefix)) - + quant_method.select_experts_impl(self.moe_parallel_config) - + assert quant_method is not None assert isinstance(quant_method, FusedMoEMethodBase) self.quant_method = quant_method @@ -932,7 +930,7 @@ def _load_w13(self, # Narrow parameter and load. # w1, gate_proj: Load into first logical weight of w13. # w3, up_proj: Load into second logical weight of w13. - # The FlashInfer Cutlass fused MoE kernel expects the combined weights + # The FlashInfer Cutlass fused MoE kernel expects the combined weights # to be ordered as [w3, w1], unlike the standard [w1, w3] layout. 
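# A hedged sketch of what the [w3, w1] reordering amounts to when loading a
# shard into the stacked w13 parameter; the helper and its arguments are
# invented for illustration and are not the actual loader code.
import torch

def sketch_load_w13(w13: torch.Tensor, shard: torch.Tensor, shard_id: str,
                    switch_w13: bool) -> None:
    # w13 stacks two equal logical weights along dim 0. In the standard
    # layout w1 (gate_proj) fills the first half and w3 (up_proj) the
    # second; with switch_w13 the halves swap so the kernel sees [w3, w1].
    half = w13.shape[0] // 2
    first = "w3" if switch_w13 else "w1"
    start = 0 if shard_id == first else half
    w13[start:start + half].copy_(shard)

_w13 = torch.empty(8, 4)
sketch_load_w13(_w13, torch.ones(4, 4), "w1", switch_w13=True)   # second half
sketch_load_w13(_w13, torch.zeros(4, 4), "w3", switch_w13=True)  # first half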
assert shard_id in ("w1", "w3") switch_w13 = getattr(self.quant_method, 'load_up_proj_weight_first', @@ -1472,12 +1470,12 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): def forward_impl(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): assert self.quant_method is not None - # Route to the chunked forward path using the FlashInfer Cutlass kernel + # Route to the chunked forward path using the FlashInfer Cutlass kernel # only when data parallelism (DP) is enabled. use_flashinfer_cutlass_kernels = self.dp_size > 1 and self.moe_parallel_config.use_flashinfer_cutlass_kernels if (self.moe_parallel_config.use_pplx_kernels - or self.moe_parallel_config.use_deepep_ll_kernels or - use_flashinfer_cutlass_kernels): + or self.moe_parallel_config.use_deepep_ll_kernels + or use_flashinfer_cutlass_kernels): return self.forward_impl_chunked(hidden_states, router_logits) do_naive_dispatch_combine: bool = ( diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index c0e74d7aa3a1..8fd868364491 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -453,17 +453,25 @@ def __init__( f"{fused_experts.activation_formats[0]}") def _do_fused_experts( - self, fused_out: Optional[torch.Tensor], a1: torch.Tensor, - a1q: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - topk_ids: torch.Tensor, activation: str, global_num_experts: int, - local_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor], + self, + fused_out: Optional[torch.Tensor], + a1: torch.Tensor, + a1q: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + local_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], expert_tokens_meta: Optional[ExpertTokensMetadata], - extra_expert_kwargs: Optional[dict] = None - ) -> torch.Tensor: + extra_expert_kwargs: Optional[dict] = None) -> torch.Tensor: _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids) @@ -508,16 +516,24 @@ def _do_fused_experts( return fused_out def _maybe_chunk_fused_experts( - self, a1: torch.Tensor, a1q: torch.Tensor, w1: torch.Tensor, - w2: torch.Tensor, topk_ids: torch.Tensor, activation: str, - global_num_experts: int, local_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - expert_tokens_meta: Optional[ExpertTokensMetadata], - extra_expert_kwargs: Optional[dict] = None, + self, + a1: torch.Tensor, + a1q: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + local_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + expert_tokens_meta: Optional[ExpertTokensMetadata], + extra_expert_kwargs: 
Optional[dict] = None, ) -> torch.Tensor: _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids) @@ -610,7 +626,7 @@ def slice_expert_tokens_metadata( c_expert_tokens_meta = slice_expert_tokens_metadata( expert_tokens_meta, c_topk_ids, local_num_experts, expert_map) - + s = chunk_idx * CHUNK_SIZE e = min(s + CHUNK_SIZE, M) @@ -619,24 +635,25 @@ def slice_expert_tokens_metadata( if topk_weights is not None: chunked_extra_expert_kwargs['topk_weights'] = topk_weights[s:e] - self._do_fused_experts(fused_out=slice_output_tensor(chunk_idx), - a1=a1, - a1q=c_a1q, - w1=w1, - w2=w2, - topk_ids=c_topk_ids, - activation=activation, - global_num_experts=global_num_experts, - local_num_experts=local_num_experts, - expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=w1_zp, - w2_zp=w2_zp, - a1q_scale=c_a1q_scale, - a2_scale=c_a2_scale, - expert_tokens_meta=c_expert_tokens_meta, - extra_expert_kwargs=chunked_extra_expert_kwargs) + self._do_fused_experts( + fused_out=slice_output_tensor(chunk_idx), + a1=a1, + a1q=c_a1q, + w1=w1, + w2=w2, + topk_ids=c_topk_ids, + activation=activation, + global_num_experts=global_num_experts, + local_num_experts=local_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, + a1q_scale=c_a1q_scale, + a2_scale=c_a2_scale, + expert_tokens_meta=c_expert_tokens_meta, + extra_expert_kwargs=chunked_extra_expert_kwargs) return fused_out @@ -731,8 +748,9 @@ def forward( fused_out = None extra_expert_kwargs = extra_expert_args or {} - - if 'topk_weights' in extra_expert_kwargs and extra_expert_kwargs['topk_weights'] is None: + + if 'topk_weights' in extra_expert_kwargs and extra_expert_kwargs[ + 'topk_weights'] is None: extra_expert_kwargs['topk_weights'] = topk_weights assert extra_expert_kwargs['topk_weights'] is not None @@ -763,7 +781,7 @@ def forward( a2_scale=a2_scale, expert_tokens_meta=expert_tokens_meta, extra_expert_kwargs=extra_expert_kwargs, - ) + ) extra_finalize_kwargs = extra_finalize_args or {} self.prepare_finalize.finalize( diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index a73b14e05949..178174b5a15f 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -38,7 +38,7 @@ def prepare( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - skip_quant: Optional[bool]=False, + skip_quant: Optional[bool] = False, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], Optional[torch.Tensor]]: @@ -49,7 +49,7 @@ def prepare( assert topk == 1, \ "apply_router_weight_on_input is only implemented for topk=1" a1.mul_(topk_weights.to(a1.dtype)) - + if skip_quant: return a1, None, None, None, None @@ -67,7 +67,7 @@ def finalize( topk_ids: torch.Tensor, apply_router_weight_on_input: bool, weight_and_reduce_impl: mk.TopKWeightAndReduce, - skip_weight_reduce: Optional[bool]=False, + skip_weight_reduce: Optional[bool] = False, ) -> None: if skip_weight_reduce: assert output.shape == fused_expert_output.shape diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index 6cf40270d39f..593d09266c41 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -195,7 +195,9 @@ def moe_kernel_quantize_input( elif quant_dtype == torch.int8: 
return _int8_quantize(A, A_scale, per_act_token_quant, block_shape) elif quant_dtype == torch.uint8: # nvfp4 - return _fp4_quantize(A, A_scale, is_sf_swizzled_layout=is_fp4_scalar_swizzled) + return _fp4_quantize(A, + A_scale, + is_sf_swizzled_layout=is_fp4_scalar_swizzled) elif quant_dtype == "mxfp4": return _mxfp4_quantize(A, A_scale, per_act_token_quant, block_shape) else: diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 39d53a74c99b..1b530cf2bfaf 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import functools from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch @@ -9,10 +8,15 @@ from torch.nn.parameter import Parameter import vllm.envs as envs +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm._custom_ops import (cutlass_scaled_fp4_mm, cutlass_scaled_mm_supports_fp4, scaled_fp4_quant) +from vllm.distributed import get_ep_group from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import fused_experts +from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( + FlashInferExperts, _valid_flashinfer_fused_moe) +from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( + FlashInferCutlassMoEPrepareAndFinalize) from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, @@ -32,17 +36,7 @@ PerTensorScaleParameter) from vllm.platforms import current_platform from vllm.scalar_type import scalar_types -from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( - FlashInferExperts, _valid_flashinfer_fused_moe) -from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( - FlashInferCutlassMoEPrepareAndFinalize) -import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.distributed import ( - get_dp_group, get_ep_group, get_tensor_model_parallel_world_size) - -from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP) try: from flashinfer import fp4_quantize as fp4_quantize from flashinfer.fused_moe import ( @@ -485,7 +479,7 @@ def __init__(self, quant_config: ModelOptNvFp4Config): self.cutlass_nvfp4_supported = cutlass_fp4_supported() self.use_marlin = False self.allow_flashinfer_cutlass = False - + if envs.VLLM_USE_FLASHINFER_MOE: if self.cutlass_nvfp4_supported and current_platform.is_cuda() \ and current_platform.has_device_capability(10, 0): @@ -506,21 +500,21 @@ def __init__(self, quant_config: ModelOptNvFp4Config): " above.") from vllm.model_executor.layers.fused_moe.cutlass_moe import ( cutlass_moe_fp4) - + self.fused_experts = cutlass_moe_fp4 @property def load_up_proj_weight_first(self) -> bool: # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13 return self.allow_flashinfer_cutlass - - def select_experts_impl(self, moe_parallel_config): + + def select_experts_impl(self, moe_parallel_config): if not self.allow_flashinfer_cutlass: # if moe_parallel_config.dp_size > 1: # raise ValueError("CutlassExpertsFp4 Doesn't support DP. 
" # "Use flashinfer CUTLASS FusedMoE backend instead.") return - + logger.debug("FlashInferExperts") # default to TP/EP case only @@ -536,7 +530,7 @@ def select_experts_impl(self, moe_parallel_config): experts = FlashInferExperts(**experts_kwargs) self.fused_experts = mk.FusedMoEModularKernel( FlashInferCutlassMoEPrepareAndFinalize( - quant_dtype=torch.uint8, + quant_dtype=torch.uint8, #meaning 2x e2m1 packed in one, kernel requirement ), experts, @@ -548,7 +542,7 @@ def load_up_proj_weight_first(self) -> bool: if self.allow_flashinfer_cutlass: return True return False - + # This method update self.fused_experts # only prepare_finalize is not None call select_gemm_impl # so when native cutlass fp4, fused_expert is in fuse_moe.py fused_expert @@ -566,18 +560,19 @@ def select_gemm_impl(self, prepare_finalize, moe): logger.debug("FlashInferExperts %s", moe) experts = FlashInferExperts( use_nvfp4_w4a4=True, - use_dp=moe.moe_parallel_config.dp_size>1, + use_dp=moe.moe_parallel_config.dp_size > 1, ep_rank=moe.moe_parallel_config.ep_rank, ep_size=moe.moe_parallel_config.ep_size, tp_rank=moe.moe_parallel_config.tp_rank, tp_size=moe.moe_parallel_config.tp_size, ) else: - assert moe.dp_size > 1 + assert moe.dp_size > 1 logger.debug("CutlassExpertsFp4 %s", moe) # current doesn't support DP - raise ValueError("CutlassExpertsFp4 Doesn't support DP. " - "Use flashinfer CUTLASS FusedMoE backend instead.") + raise ValueError( + "CutlassExpertsFp4 Doesn't support DP. " + "Use flashinfer CUTLASS FusedMoE backend instead.") return experts @@ -819,16 +814,17 @@ def apply( if self.allow_flashinfer_cutlass: # TP or DP case assert _valid_flashinfer_fused_moe( - x, layer.w13_weight, layer.w2_weight), ("Flashinfer CUTLASS Fused MoE not applicable!") + x, layer.w13_weight, layer.w2_weight), ( + "Flashinfer CUTLASS Fused MoE not applicable!") extra_expert_args = { - 'topk_weights': None, # placeholder topk_weights, - 'g1_alphas': layer.g1_alphas, - 'g2_alphas': layer.g2_alphas, - 'out_dtype': x.dtype, - # Avoid confusion with a1_scale and a2_scale whare are batch size - # related. - 'a1_gscale': a1_gscale, - 'a2_gscale': a2_gscale, + 'topk_weights': None, # placeholder topk_weights, + 'g1_alphas': layer.g1_alphas, + 'g2_alphas': layer.g2_alphas, + 'out_dtype': x.dtype, + # Avoid confusion with a1_scale and a2_scale whare are batch size + # related. 
+ 'a1_gscale': a1_gscale, + 'a2_gscale': a2_gscale, } extra_prepare_args = { 'use_dp': layer.dp_size > 1, @@ -857,8 +853,6 @@ def apply( extra_finalize_args=extra_finalize_args, ) else: - from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - run_cutlass_moe_fp4) # cutlass_moe_fp4, TP case only(no EP) out = self.fused_experts( From 8fdb3dc5985d72793430424914dad697daf01ae4 Mon Sep 17 00:00:00 2001 From: shuw Date: Sat, 12 Jul 2025 04:34:36 +0000 Subject: [PATCH 16/30] Merge Pynccl ag/rs --- benchmarks/benchmark_throughput.py | 3 --- tests/distributed/test_pynccl.py | 2 +- .../device_communicators/all2all.py | 1 + .../device_communicators/cuda_communicator.py | 24 ++++++++++++++++--- .../device_communicators/pynccl_wrapper.py | 2 +- vllm/distributed/parallel_state.py | 13 ++++------ 6 files changed, 29 insertions(+), 16 deletions(-) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 7db41916996e..14461121fece 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -28,7 +28,6 @@ VisionArenaDataset, ) from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json -from vllm.distributed import cleanup_dist_env_and_memory from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs from vllm.entrypoints.openai.api_server import ( build_async_engine_client_from_engine_args, @@ -111,8 +110,6 @@ def run_vllm( ), ) end = time.perf_counter() - - cleanup_dist_env_and_memory() return end - start, outputs diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index a646bb25fc17..abfad9ebfe7d 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -399,4 +399,4 @@ def test_ncclGetUniqueId(): # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] # as long as the function doesn't raise an exception, we're good - assert unique_id is not None \ No newline at end of file + assert unique_id is not None diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index c9e1e1675c45..85f87cb21edc 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -65,6 +65,7 @@ def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ self.dp_rank - 1] end = cu_tokens_across_dp_cpu[self.dp_rank] + all_hidden_states = self.dp_group.all_reduce(hidden_states) hidden_states = all_hidden_states[start:end, :] return hidden_states diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 3e4a52c54559..e4804691f0f6 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -8,6 +8,7 @@ import vllm.envs as envs from vllm.logger import init_logger +from vllm.platforms import current_platform from .base_device_communicator import DeviceCommunicatorBase @@ -41,6 +42,8 @@ def __init__(self, CustomAllreduce) from vllm.distributed.device_communicators.pynccl import ( PyNcclCommunicator) + from vllm.distributed.device_communicators.quick_all_reduce import ( + QuickAllReduce) self.pynccl_comm: Optional[PyNcclCommunicator] = None if use_pynccl and self.world_size > 1: @@ -50,6 +53,7 @@ def __init__(self, ) self.ca_comm: Optional[CustomAllreduce] = None + self.qr_comm: 
Optional[QuickAllReduce] = None if use_custom_allreduce and self.world_size > 1: # Initialize a custom fast all-reduce implementation. self.ca_comm = CustomAllreduce( @@ -57,6 +61,14 @@ def __init__(self, device=self.device, ) + if current_platform.is_rocm(): + # Initialize a custom quick all-reduce implementation for AMD. + # Quick reduce is designed as a complement to custom allreduce. + # Based on quickreduce (https://github.com/mk1-project/quickreduce). + # If it's a rocm, 'use_custom_allreduce==True' means it must + # currently be an MI300 series. + self.qr_comm = QuickAllReduce(group=self.cpu_group, + device=self.device) if self.use_all2all: all2all_backend = envs.VLLM_ALL2ALL_BACKEND if all2all_backend == "naive": @@ -79,8 +91,14 @@ def __init__(self, raise ValueError(f"Unknown all2all backend: {all2all_backend}") def all_reduce(self, input_): - # always try custom allreduce first, - # and then pynccl. + # always try quick reduce first, then custom allreduce, + # and then pynccl. (quick reduce just for ROCM MI3*) + qr_comm = self.qr_comm + if qr_comm is not None and not qr_comm.disabled and \ + qr_comm.should_quick_allreduce(input_): + out = qr_comm.quick_all_reduce(input_) + assert out is not None + return out ca_comm = self.ca_comm if ca_comm is not None and not ca_comm.disabled and \ ca_comm.should_custom_ar(input_): @@ -254,4 +272,4 @@ def dispatch( def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: assert self.all2all_manager is not None hidden_states = self.all2all_manager.combine(hidden_states) - return hidden_states \ No newline at end of file + return hidden_states diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py index 8efda853c05a..a930b63bc26f 100644 --- a/vllm/distributed/device_communicators/pynccl_wrapper.py +++ b/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -379,4 +379,4 @@ def ncclGroupEnd(self) -> None: __all__ = [ "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId", "ncclComm_t", "cudaStream_t", "buffer_type" -] \ No newline at end of file +] diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index ed6a40cb375e..1bb0ca79cc1d 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -30,7 +30,7 @@ from contextlib import contextmanager, nullcontext from dataclasses import dataclass from multiprocessing import shared_memory -from typing import Any, Callable, List, Optional, Union +from typing import Any, Callable, Optional, Union from unittest.mock import patch import torch @@ -371,8 +371,7 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: assert -input_.dim() <= dim < input_.dim(), ( f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") - # TODO(shuw): enable it - if self.use_custom_op_call and False: + if self.use_custom_op_call: return torch.ops.vllm.all_gather(input_, dim, world_size, @@ -392,8 +391,7 @@ def all_gatherv(self, def reduce_scatter(self, input_: torch.Tensor, - dim: int = -1, - sizes: Optional[List[int]] = None) -> torch.Tensor: + dim: int = -1) -> torch.Tensor: world_size = self.world_size # Bypass the function if we are using only 1 GPU. 
if world_size == 1: @@ -401,14 +399,13 @@ def reduce_scatter(self, assert -input_.dim() <= dim < input_.dim(), ( f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") - if self.use_custom_op_call and False: - assert sizes is None, "Varying size reduce scatter not supported with vllm custom op" + if self.use_custom_op_call: return torch.ops.vllm.reduce_scatter(input_, dim, world_size, group_name=self.unique_name) else: - return self._reduce_scatter_out_place(input_, dim, sizes) + return self._reduce_scatter_out_place(input_, dim) def reduce_scatterv(self, input_: torch.Tensor, From 5bcd50bff2d9f025f33e268ee632beb57178ffc0 Mon Sep 17 00:00:00 2001 From: shuw Date: Tue, 15 Jul 2025 04:22:50 +0000 Subject: [PATCH 17/30] Recover apply_router_weight_on_input --- vllm/model_executor/layers/fused_moe/cutlass_moe.py | 5 ++++- vllm/model_executor/layers/quantization/modelopt.py | 7 ++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index d872b552f656..0f1e6f381805 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -652,6 +652,7 @@ def apply( k: int, e: int, device: torch.device, + apply_router_weight_on_input, ): assert expert_map is None, ("Expert Parallelism / expert_map " "is currently not supported for " @@ -677,6 +678,7 @@ def apply( k, e, device, + apply_router_weight_on_input, ) @@ -698,6 +700,7 @@ def cutlass_moe_fp4( e: int, device: torch.device, expert_map: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False ) -> torch.Tensor: assert expert_map is None, ("Expert Parallelism / expert_map " "is currently not supported for " @@ -749,7 +752,7 @@ def cutlass_moe_fp4( w2_scale=w2_blockscale, a1_scale=None, a2_scale=None, - apply_router_weight_on_input=False, + apply_router_weight_on_input=apply_router_weight_on_input, extra_expert_args=extra_expert_args, extra_prepare_args=extra_prepare_args, extra_finalize_args=extra_finalize_args, diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 1b530cf2bfaf..77416f5d8f74 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -482,7 +482,7 @@ def __init__(self, quant_config: ModelOptNvFp4Config): if envs.VLLM_USE_FLASHINFER_MOE: if self.cutlass_nvfp4_supported and current_platform.is_cuda() \ - and current_platform.has_device_capability(10, 0): + and current_platform.is_device_capability(100): logger.info_once( "Using FlashInfer kernels for ModelOptNvFp4FusedMoE.") self.allow_flashinfer_cutlass = True @@ -510,9 +510,6 @@ def load_up_proj_weight_first(self) -> bool: def select_experts_impl(self, moe_parallel_config): if not self.allow_flashinfer_cutlass: - # if moe_parallel_config.dp_size > 1: - # raise ValueError("CutlassExpertsFp4 Doesn't support DP. 
" - # "Use flashinfer CUTLASS FusedMoE backend instead.") return logger.debug("FlashInferExperts") @@ -853,7 +850,6 @@ def apply( extra_finalize_args=extra_finalize_args, ) else: - # cutlass_moe_fp4, TP case only(no EP) out = self.fused_experts( a=x, @@ -873,5 +869,6 @@ def apply( e=layer.w13_weight.shape[0], device=x.device, expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input ) return out From 6fed4949e198f320a6e90199cd5ca3ba00219337 Mon Sep 17 00:00:00 2001 From: shuw Date: Tue, 15 Jul 2025 17:53:21 +0000 Subject: [PATCH 18/30] remove comment --- vllm/model_executor/layers/fused_moe/modular_kernel.py | 4 ---- vllm/model_executor/layers/quantization/modelopt.py | 1 - 2 files changed, 5 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 9567b668dc02..a4ed10c24896 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -624,7 +624,6 @@ def slice_expert_tokens_metadata( expert_num_tokens=c_expert_num_tokens, expert_num_tokens_cpu=c_expert_num_tokens_cpu) - # topk_weights = extra_expert_kwargs.get('topk_weights') m = extra_expert_kwargs.get('m') chunked_extra_expert_kwargs = extra_expert_kwargs @@ -643,9 +642,6 @@ def slice_expert_tokens_metadata( if m is not None: chunked_extra_expert_kwargs['m'] = e - s - # if topk_weights is not None: - # chunked_extra_expert_kwargs['topk_weights'] = topk_weights[s:e] - self._do_fused_experts( fused_out=slice_output_tensor(chunk_idx), a1=a1, diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index db5bda51ce22..abddf313f7b8 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -1070,7 +1070,6 @@ def apply( x, layer.w13_weight, layer.w2_weight), ( "Flashinfer CUTLASS Fused MoE not applicable!") extra_expert_args = { - # 'topk_weights': None, # placeholder topk_weights, 'g1_alphas': layer.g1_alphas, 'g2_alphas': layer.g2_alphas, 'out_dtype': x.dtype, From 20f3417d5de518c23c9a8a818605877048c449b5 Mon Sep 17 00:00:00 2001 From: shuw Date: Tue, 15 Jul 2025 18:05:30 +0000 Subject: [PATCH 19/30] fix lint --- .../layers/fused_moe/cutlass_moe.py | 41 +++++++++---------- .../fused_moe/flashinfer_cutlass_moe.py | 2 +- .../layers/fused_moe/modular_kernel.py | 38 +++++++++-------- .../layers/quantization/modelopt.py | 3 +- 4 files changed, 44 insertions(+), 40 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 95a539becc67..f53b1d4c788f 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -643,7 +643,7 @@ def apply( k: int, e: int, device: torch.device, - apply_router_weight_on_input, + apply_router_weight_on_input, ): assert expert_map is None, ("Expert Parallelism / expert_map " "is currently not supported for " @@ -669,30 +669,29 @@ def apply( k, e, device, - apply_router_weight_on_input, + apply_router_weight_on_input, ) def cutlass_moe_fp4( - a: torch.Tensor, - w1_fp4: torch.Tensor, - w2_fp4: torch.Tensor, - w1_blockscale: torch.Tensor, - w2_blockscale: torch.Tensor, - g1_alphas: torch.Tensor, - g2_alphas: torch.Tensor, - a1_gscale: torch.Tensor, - a2_gscale: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - m: int, - n: int, - k: int, - e: int, - 
device: torch.device, - expert_map: Optional[torch.Tensor] = None, - apply_router_weight_on_input: bool = False -) -> torch.Tensor: + a: torch.Tensor, + w1_fp4: torch.Tensor, + w2_fp4: torch.Tensor, + w1_blockscale: torch.Tensor, + w2_blockscale: torch.Tensor, + g1_alphas: torch.Tensor, + g2_alphas: torch.Tensor, + a1_gscale: torch.Tensor, + a2_gscale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + m: int, + n: int, + k: int, + e: int, + device: torch.device, + expert_map: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False) -> torch.Tensor: assert expert_map is None, ("Expert Parallelism / expert_map " "is currently not supported for " "ModelOptNvFp4FusedMoE's cutlass_moe_fp4.") diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 07146e31cddb..5762abf031cb 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -149,7 +149,7 @@ def apply( workspace13: Optional[torch.Tensor], workspace2: Optional[torch.Tensor], expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - apply_router_weight_on_input: Optional[bool], # Not used + apply_router_weight_on_input: Optional[bool], # Not used g1_alphas: torch.Tensor, g2_alphas: torch.Tensor, a1_gscale: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index a4ed10c24896..a0e444fd171e 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -459,22 +459,28 @@ def __init__( f"{fused_experts.__class__.__name__}." f"{fused_experts.activation_formats[0]}") - def _do_fused_experts(self, fused_out: Optional[torch.Tensor], - a1: torch.Tensor, a1q: torch.Tensor, - w1: torch.Tensor, w2: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - activation: str, global_num_experts: int, - local_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - expert_tokens_meta: Optional[ExpertTokensMetadata], - apply_router_weight_on_input: bool, - extra_expert_kwargs: Optional[dict] = None) -> torch.Tensor: + def _do_fused_experts( + self, + fused_out: Optional[torch.Tensor], + a1: torch.Tensor, + a1q: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + local_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + expert_tokens_meta: Optional[ExpertTokensMetadata], + apply_router_weight_on_input: bool, + extra_expert_kwargs: Optional[dict] = None) -> torch.Tensor: _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index abddf313f7b8..69c3aaf5f757 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -1124,6 +1124,5 @@ def apply( 
e=layer.w13_weight.shape[0], device=x.device, expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input - ) + apply_router_weight_on_input=apply_router_weight_on_input) return out From 728275bc87fbd708ffabb5f8907b9193929e6b7e Mon Sep 17 00:00:00 2001 From: shuw Date: Tue, 15 Jul 2025 21:42:11 +0000 Subject: [PATCH 20/30] Add autotune --- vllm/v1/worker/gpu_model_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index af216539c900..9c1f1750c252 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2021,12 +2021,12 @@ def _dummy_run( intermediate_tensors = self.sync_and_slice_intermediate_tensors( num_tokens, None, False) - + from flashinfer import AutoTuner, autotune with self.maybe_randomize_inputs(input_ids), set_forward_context( attn_metadata, self.vllm_config, num_tokens=num_tokens, - num_tokens_across_dp=num_tokens_across_dp): + num_tokens_across_dp=num_tokens_across_dp), autotune(): outputs = model( input_ids=input_ids, positions=positions, From bbb505f78e373df606096ab3cf691aae362f9d59 Mon Sep 17 00:00:00 2001 From: mgoin Date: Tue, 15 Jul 2025 22:38:17 -0400 Subject: [PATCH 21/30] Add flashinfer wrapper and fix pre-commit Signed-off-by: mgoin --- .../model_executor/layers/fused_moe/config.py | 15 +-- .../fused_moe/flashinfer_cutlass_moe.py | 40 +++--- .../flashinfer_cutlass_prepare_finalize.py | 4 +- vllm/model_executor/layers/fused_moe/layer.py | 18 +-- vllm/model_executor/layers/fused_moe/utils.py | 2 +- .../layers/quantization/modelopt.py | 37 ++---- vllm/utils/flashinfer.py | 117 ++++++++++++++++++ vllm/v1/worker/gpu_model_runner.py | 4 +- 8 files changed, 162 insertions(+), 75 deletions(-) create mode 100644 vllm/utils/flashinfer.py diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index d5d4eba950a3..b0fb562b5afc 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional, Union +from typing import Optional, Union import torch from compressed_tensors.quantization import (QuantizationArgs, @@ -15,15 +15,7 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.utils import cdiv - -try: - from flashinfer import fp4_quantize as fp4_quantize - from flashinfer.fused_moe import ( - cutlass_fused_moe as flashinfer_cutlass_fused_moe) -except ImportError: - if not TYPE_CHECKING: - flashinfer_cutlass_fused_moe = None -has_flashinfer = flashinfer_cutlass_fused_moe is not None +from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe logger = init_logger(__name__) @@ -199,7 +191,8 @@ def use_deepep_ll_kernels(self): @property def use_flashinfer_cutlass_kernels(self): - return envs.VLLM_USE_FLASHINFER_MOE and has_flashinfer + return (envs.VLLM_USE_FLASHINFER_MOE + and has_flashinfer_cutlass_fused_moe()) @staticmethod def make(tp_size_: int, dp_size_: int, diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 5762abf031cb..3b236aeb8fae 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py 
@@ -9,41 +9,31 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate) +from vllm.utils.flashinfer import (flashinfer_cutlass_fused_moe, + has_flashinfer_cutlass_fused_moe) logger = init_logger(__name__) -from typing import TYPE_CHECKING -try: - from flashinfer import fp4_quantize as fp4_quantize - from flashinfer.fused_moe import ( - cutlass_fused_moe as flashinfer_cutlass_fused_moe) -except ImportError: - if not TYPE_CHECKING: - cutlass_fused_moe = None - -has_flashinfer_cutlass_fused_moe = flashinfer_cutlass_fused_moe is not None - - -def _valid_flashinfer_fused_moe(hidden_states: torch.Tensor, w1: torch.Tensor, - w2: torch.Tensor) -> bool: +def is_valid_flashinfer_cutlass_fused_moe(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor) -> bool: """ Check if the given problem size is supported by the FlashInfer CUTLASS MoE kernel. """ - if not has_flashinfer_cutlass_fused_moe: - logger.debug( - "FlashInferExperts disabled: flashinfer_cutlass_fused_moe not available." - ) + if not has_flashinfer_cutlass_fused_moe(): + logger.debug_once("FlashInferExperts disabled: " + "flashinfer_cutlass_fused_moe not available.") return False # Data type checks if (w1.dtype != torch.uint8 or w2.dtype != torch.uint8 or hidden_states.dtype not in [torch.float32, torch.float16, torch.bfloat16]): - logger.debug( - f"FlashInferExperts disabled: w1/w2 must be torch.uint8 (got w1={w1.dtype}, w2={w2.dtype}), " - f"hidden_states must be float32, float16, or bfloat16 (got {hidden_states.dtype})." - ) + logger.debug_once( + "FlashInferExperts disabled: w1/w2 must be torch.uint8 " + f"(got w1={w1.dtype}, w2={w2.dtype}), hidden_states must be " + f"float32, float16, or bfloat16 (got {hidden_states.dtype}).") return False return True @@ -160,6 +150,12 @@ def apply( # min because inv_scale. 
assert self.use_nvfp4_w4a4 is True, ("Only nvfp4 quantization is " "currently supported.") + + # Ensure w1_scale and w2_scale are not None before calling view + assert w1_scale is not None and w2_scale is not None, ( + "w1_scale and w2_scale must not " + "be None for FlashInferExperts") + quant_scales = [ a1_gscale, w1_scale.view(torch.int32), diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py index 996825273fc6..d0c928e2c5b9 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -11,6 +11,7 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.utils import ( moe_kernel_quantize_input) +from vllm.utils.flashinfer import fp4_swizzle_blockscale def get_local_sizes(local_tokens): @@ -92,11 +93,10 @@ def prepare( ) if use_dp: topk_weights, topk_ids, a1q, a1q_scale = \ - get_dp_group().all_gatherv([topk_weights, topk_ids, a1q, a1q_scale], + get_dp_group().all_gatherv([topk_weights, topk_ids, a1q, a1q_scale], # noqa: E501 dim=0, sizes=get_local_sizes(local_tokens)) a1_m, a1_n = a1q.shape - from flashinfer import fp4_swizzle_blockscale a1q_scale = fp4_swizzle_blockscale(a1q_scale, a1_m, a1_n * 2) return a1q, a1q_scale, None, topk_ids, topk_weights diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 61763cdcd2a9..5f1b7d39510a 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -4,7 +4,7 @@ from abc import abstractmethod from collections.abc import Iterable from enum import Enum -from typing import TYPE_CHECKING, Callable, Literal, Optional, overload +from typing import Callable, Literal, Optional, overload import torch import torch.nn.functional as F @@ -34,15 +34,7 @@ from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx - -try: - from flashinfer import fp4_quantize as fp4_quantize - from flashinfer.fused_moe import ( - cutlass_fused_moe as flashinfer_cutlass_fused_moe) -except ImportError: - if not TYPE_CHECKING: - flashinfer_cutlass_fused_moe = None -has_flashinfer = flashinfer_cutlass_fused_moe is not None +from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe if current_platform.is_cuda_alike(): from .fused_batched_moe import BatchedTritonExperts @@ -54,7 +46,7 @@ from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SHAPE, DeepEPLLPrepareAndFinalize) - if has_flashinfer: + if has_flashinfer_cutlass_fused_moe(): from .flashinfer_cutlass_prepare_finalize import ( FlashInferCutlassMoEPrepareAndFinalize) else: @@ -1493,7 +1485,9 @@ def forward_impl(self, hidden_states: torch.Tensor, assert self.quant_method is not None # Route to the chunked forward path using the FlashInfer Cutlass kernel # only when data parallelism (DP) is enabled. 
- use_flashinfer_cutlass_kernels = self.dp_size > 1 and self.moe_parallel_config.use_flashinfer_cutlass_kernels + use_flashinfer_cutlass_kernels = ( + self.dp_size > 1 + and self.moe_parallel_config.use_flashinfer_cutlass_kernels) if (self.moe_parallel_config.use_pplx_kernels or self.moe_parallel_config.use_deepep_ll_kernels or use_flashinfer_cutlass_kernels): diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index fc802f0b90bf..ccfe99f35750 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -4,7 +4,6 @@ from typing import Optional, Union import torch -from flashinfer import fp4_quantize as fp4_quantize from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.fp8_utils import ( @@ -16,6 +15,7 @@ from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import cdiv +from vllm.utils.flashinfer import fp4_quantize @triton.jit diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 69c3aaf5f757..0fa7d9936add 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import Any, Callable, Optional, Union import torch from torch.nn import Module @@ -13,9 +13,7 @@ cutlass_scaled_mm_supports_fp4, scaled_fp4_quant) from vllm.distributed import get_ep_group from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( - FlashInferExperts, _valid_flashinfer_fused_moe) -from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( +from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 FlashInferCutlassMoEPrepareAndFinalize) from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) @@ -37,14 +35,6 @@ from vllm.platforms import current_platform from vllm.scalar_type import scalar_types -try: - from flashinfer import fp4_quantize as fp4_quantize - from flashinfer.fused_moe import ( - cutlass_fused_moe as flashinfer_cutlass_fused_moe) -except ImportError: - if not TYPE_CHECKING: - flashinfer_cutlass_fused_moe = None - logger = init_logger(__name__) QUANT_ALGOS = ["FP8", "NVFP4"] @@ -738,8 +728,8 @@ def __init__(self, quant_config: ModelOptNvFp4Config): self.allow_flashinfer_cutlass = True else: logger.warning_once( - "Flashinfer CUTLASS Fused MoE not supported or found on the current platform." 
- ) + "Flashinfer CUTLASS Fused MoE not supported " + "or found on the current platform.") if not self.cutlass_nvfp4_supported: if is_fp4_marlin_supported(): @@ -774,6 +764,8 @@ def select_experts_impl(self, moe_parallel_config): experts_kwargs["ep_size"] = moe_parallel_config.ep_size experts_kwargs["tp_rank"] = moe_parallel_config.tp_rank experts_kwargs["tp_size"] = moe_parallel_config.tp_size + from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 + FlashInferExperts) experts = FlashInferExperts(**experts_kwargs) self.fused_experts = mk.FusedMoEModularKernel( FlashInferCutlassMoEPrepareAndFinalize( @@ -783,13 +775,6 @@ def select_experts_impl(self, moe_parallel_config): experts, ) - @property - def load_up_proj_weight_first(self) -> bool: - # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13 - if self.allow_flashinfer_cutlass: - return True - return False - # This method update self.fused_experts # only prepare_finalize is not None call select_gemm_impl # so when native cutlass fp4, fused_expert is in fuse_moe.py fused_expert @@ -802,7 +787,7 @@ def select_gemm_impl(self, prepare_finalize, moe): all2all_manager = get_ep_group().device_communicator.all2all_manager assert all2all_manager is not None if self.allow_flashinfer_cutlass: - from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( + from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 FlashInferExperts) logger.debug("FlashInferExperts %s", moe) experts = FlashInferExperts( @@ -1066,15 +1051,17 @@ def apply( a2_gscale = torch.min(layer.w2_input_scale_quant) if self.allow_flashinfer_cutlass: # TP or DP case - assert _valid_flashinfer_fused_moe( + from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 + is_valid_flashinfer_cutlass_fused_moe) + assert is_valid_flashinfer_cutlass_fused_moe( x, layer.w13_weight, layer.w2_weight), ( "Flashinfer CUTLASS Fused MoE not applicable!") extra_expert_args = { 'g1_alphas': layer.g1_alphas, 'g2_alphas': layer.g2_alphas, 'out_dtype': x.dtype, - # Avoid confusion with a1_scale and a2_scale whare are batch size - # related. + # Avoid confusion with a1_scale and a2_scale + # where are batch size related. 'a1_gscale': a1_gscale, 'a2_gscale': a2_gscale, } diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py new file mode 100644 index 000000000000..e0523a20e21c --- /dev/null +++ b/vllm/utils/flashinfer.py @@ -0,0 +1,117 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Compatibility wrapper for FlashInfer API changes. + +Users of vLLM should always import **only** these wrappers. +""" +from __future__ import annotations + +import contextlib +import functools +import importlib +from typing import Any, Callable, NoReturn + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +@functools.cache +def has_flashinfer() -> bool: + """Return ``True`` if FlashInfer is available.""" + try: + import flashinfer # noqa: F401 + return True + except ImportError: + return False + + +def _missing(*_: Any, **__: Any) -> NoReturn: + """Placeholder for unavailable FlashInfer backend.""" + raise RuntimeError( + "FlashInfer backend is not available. 
Please install the package " + "to enable FlashInfer kernels: " + "https://github.com/flashinfer-ai/flashinfer") + + +def _get_submodule(module_name: str) -> Any | None: + """Safely import a submodule and return it, or None if not available.""" + try: + return importlib.import_module(module_name) + except (ImportError, ModuleNotFoundError): + return None + + +# Initialize FlashInfer components +if not has_flashinfer(): + _cutlass_fused_moe_impl: Callable[..., Any] | None = None + _fp4_quantize_impl: Callable[..., Any] | None = None + _fp4_swizzle_blockscale_impl: Callable[..., Any] | None = None + _autotune_impl: Callable[..., Any] | None = None +else: + # Import main flashinfer module + _fi = importlib.import_module("flashinfer") # type: ignore + + # Import fused_moe submodule + _fused_moe_mod = _get_submodule("flashinfer.fused_moe") + _cutlass_fused_moe_impl = getattr(_fused_moe_mod, "cutlass_fused_moe", + None) if _fused_moe_mod else None + + # Import fp4_quant functions + _fp4_quantize_impl = getattr(_fi, "fp4_quantize", None) if _fi else None + _fp4_swizzle_blockscale_impl = getattr(_fi, "fp4_swizzle_blockscale", + None) if _fi else None + + # Import autotuner submodule + _autotuner_mod = _get_submodule("flashinfer.autotuner") + _autotune_impl = getattr(_autotuner_mod, "autotune", + None) if _autotuner_mod else None + + +@functools.cache +def has_flashinfer_cutlass_fused_moe() -> bool: + """Return ``True`` if FlashInfer CUTLASS fused MoE is available.""" + return all([ + _cutlass_fused_moe_impl, + _fp4_quantize_impl, + _fp4_swizzle_blockscale_impl, + ]) + + +def flashinfer_cutlass_fused_moe(*args, **kwargs): + """FlashInfer CUTLASS fused MoE kernel.""" + if _cutlass_fused_moe_impl is None: + return _missing(*args, **kwargs) + return _cutlass_fused_moe_impl(*args, **kwargs) + + +def fp4_quantize(*args, **kwargs): + """FlashInfer FP4 quantization.""" + if _fp4_quantize_impl is None: + return _missing(*args, **kwargs) + return _fp4_quantize_impl(*args, **kwargs) + + +def fp4_swizzle_blockscale(*args, **kwargs): + """FlashInfer FP4 swizzle blockscale.""" + if _fp4_swizzle_blockscale_impl is None: + return _missing(*args, **kwargs) + return _fp4_swizzle_blockscale_impl(*args, **kwargs) + + +def autotune(*args, **kwargs): + """FlashInfer autotuner.""" + if _autotune_impl is None: + # return a null context since autotune is a context manager + return contextlib.nullcontext() + return _autotune_impl(*args, **kwargs) + + +__all__ = [ + "has_flashinfer", + "has_flashinfer_cutlass_fused_moe", + "flashinfer_cutlass_fused_moe", + "fp4_quantize", + "fp4_swizzle_blockscale", + "autotune", +] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9c1f1750c252..af216539c900 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2021,12 +2021,12 @@ def _dummy_run( intermediate_tensors = self.sync_and_slice_intermediate_tensors( num_tokens, None, False) - from flashinfer import AutoTuner, autotune + with self.maybe_randomize_inputs(input_ids), set_forward_context( attn_metadata, self.vllm_config, num_tokens=num_tokens, - num_tokens_across_dp=num_tokens_across_dp), autotune(): + num_tokens_across_dp=num_tokens_across_dp): outputs = model( input_ids=input_ids, positions=positions, From 841fcd1f3fefe5231a26dc9a2caced22dcdd9db9 Mon Sep 17 00:00:00 2001 From: shuw Date: Wed, 16 Jul 2025 16:21:05 +0000 Subject: [PATCH 22/30] Move switch w13 to modelopt --- vllm/_custom_ops.py | 3 +- .../layers/fused_moe/cutlass_moe.py | 22 ++++---- 
.../flashinfer_cutlass_prepare_finalize.py | 2 +- vllm/model_executor/layers/fused_moe/layer.py | 14 ++--- vllm/model_executor/layers/fused_moe/utils.py | 4 +- .../layers/quantization/modelopt.py | 52 ++++++++++++------- 6 files changed, 54 insertions(+), 43 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index eaad7ac135eb..c81e64e7b05c 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -961,8 +961,7 @@ def cutlass_fp4_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, b_tensors: torch.Tensor, a_scales: torch.Tensor, b_scales: torch.Tensor, alphas: torch.Tensor, problem_sizes: torch.Tensor, - expert_offsets: torch.Tensor, sf_offsets: torch.Tensor, - out_dtype: torch.dtype, device: torch.device): + expert_offsets: torch.Tensor, sf_offsets: torch.Tensor): """ An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs the gemms for each combination based on the specified problem sizes. diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index f53b1d4c788f..e0114dfb69bd 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -523,8 +523,7 @@ def run_cutlass_moe_fp4( c3 = _resize_cache(workspace13, (m * topk, k)) ops.cutlass_fp4_moe_mm(c1, rep_a_fp4, w1_fp4, rep_a_blockscale, w1_blockscale, w1_alphas, problem_sizes1, - expert_offsets[:-1], blockscale_offsets[:-1], - out_dtype, device) + expert_offsets[:-1], blockscale_offsets[:-1]) del rep_a_fp4, rep_a_blockscale torch.ops._C.silu_and_mul(c2, c1) int_fp4, int_blockscale = ops.scaled_fp4_experts_quant( @@ -532,15 +531,18 @@ def run_cutlass_moe_fp4( ops.cutlass_fp4_moe_mm(c3, int_fp4, w2_fp4, int_blockscale, w2_blockscale, w2_alphas, problem_sizes2, expert_offsets[:-1], - blockscale_offsets[:-1], out_dtype, device) + blockscale_offsets[:-1]) del int_fp4, int_blockscale c3 = ops.shuffle_rows(c3, c_map) assert output.dtype == out_dtype - output.copy_((c3.view(m, num_topk, k) * - topk_weights.view(m, num_topk, 1).half()).sum(dim=1), - non_blocking=True) + if not apply_router_weight_on_input: + output.copy_((c3.view(m, num_topk, k) * + topk_weights.view(m, num_topk, 1).to(out_dtype)).sum(dim=1), + non_blocking=True) + else: + output.copy_(c2.view(m, num_topk, k).sum(dim=1), non_blocking=True) return @@ -706,7 +708,6 @@ def cutlass_moe_fp4( ), ) extra_expert_args = { - # 'topk_weights': topk_weights, 'g1_alphas': g1_alphas, 'g2_alphas': g2_alphas, 'a1_gscale': a1_gscale, @@ -718,9 +719,10 @@ def cutlass_moe_fp4( 'device': device, } - # NVFP4 requires two levels of quantization, which involves computing some scaling - # factors dynamically. This makes it incompatible with the typical - # prepare -> MoE -> finalize pipeline. Move the quantization logic into the MoE body. + # NVFP4 requires two levels of quantization, which involves computing some + # scaling factors dynamically. This makes it incompatible with the typical + # prepare -> MoE -> finalize pipeline. Move the quantization logic into the + # MoE body. 
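The comment above captures the main design constraint in this file: NVFP4 combines a per-tensor global scale with per-block scales that are computed dynamically on the permuted, per-expert inputs, so the activation quantization cannot be hoisted into prepare(). A rough, framework-free sketch of the two-level scaling idea (my own illustration under stated assumptions -- an FP4/E2M1 max magnitude of 6.0 and 16-element blocks -- not the actual scaled_fp4_experts_quant kernel):

    import torch

    def two_level_scales(a: torch.Tensor, block: int = 16):
        # Assumes a.numel() is a multiple of `block`.
        fp4_max = 6.0                                      # largest |value| in E2M1
        global_scale = a.abs().amax().clamp(min=1e-12) / fp4_max  # level 1: one scalar
        blocks = a.reshape(-1, block)
        # Level 2: per-block scales relative to the global scale; these depend
        # on which rows end up in each block, i.e. on the expert grouping.
        block_scale = blocks.abs().amax(dim=-1, keepdim=True) / (
            global_scale * fp4_max)
        q = torch.clamp(
            blocks / (block_scale.clamp(min=1e-12) * global_scale),
            -fp4_max, fp4_max)
        return q.reshape_as(a), block_scale, global_scale
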
extra_prepare_args = { 'skip_quant': True, } diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py index d0c928e2c5b9..611dcbe4555d 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -88,7 +88,7 @@ def prepare( quant_config.quant_dtype, self.per_channel_quant, self.block_shape, - is_fp4_scalar_swizzled= + is_fp4_scale_swizzled= not use_dp, # Needs swizzling after communication ) if use_dp: diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 5f1b7d39510a..19d7943ab8e8 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -931,18 +931,12 @@ def _load_w13(self, shard_size) # Narrow parameter and load. # w1, gate_proj: Load into first logical weight of w13. + if shard_id == "w1": + expert_data = expert_data.narrow(shard_dim, 0, shard_size) # w3, up_proj: Load into second logical weight of w13. - # The FlashInfer Cutlass fused MoE kernel expects the combined weights - # to be ordered as [w3, w1], unlike the standard [w1, w3] layout. - assert shard_id in ("w1", "w3") - switch_w13 = getattr(self.quant_method, 'load_up_proj_weight_first', - False) - if (switch_w13 and shard_id == "w1") or (not switch_w13 - and shard_id == "w3"): - start = shard_size else: - start = 0 - expert_data = expert_data.narrow(shard_dim, start, shard_size) + assert shard_id == "w3" + expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) expert_data.copy_(loaded_weight) def _load_w2(self, diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index ccfe99f35750..61841b9faa7b 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -183,7 +183,7 @@ def moe_kernel_quantize_input( quant_dtype: Union[None, torch.dtype, str], per_act_token_quant: bool, block_shape: Optional[list[int]] = None, - is_fp4_scalar_swizzled: bool = True, + is_fp4_scale_swizzled: bool = True, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: if quant_dtype == torch.float8_e4m3fn: return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape) @@ -192,7 +192,7 @@ def moe_kernel_quantize_input( elif quant_dtype == torch.uint8: # nvfp4 return _fp4_quantize(A, A_scale, - is_sf_swizzled_layout=is_fp4_scalar_swizzled) + is_sf_swizzled_layout=is_fp4_scale_swizzled) elif quant_dtype == "mxfp4": return _mxfp4_quantize(A, A_scale, per_act_token_quant, block_shape) else: diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 0fa7d9936add..2daba8b232ef 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -743,23 +743,18 @@ def __init__(self, quant_config: ModelOptNvFp4Config): self.fused_experts = cutlass_moe_fp4 - @property - def load_up_proj_weight_first(self) -> bool: - # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13 - return self.allow_flashinfer_cutlass - - def select_experts_impl(self, moe_parallel_config): + def select_experts_impl( + self, moe_parallel_config) -> mk.FusedMoEPermuteExpertsUnpermute: if not self.allow_flashinfer_cutlass: return - logger.debug("FlashInferExperts") + logger.debug_once("FlashInferExperts") # default to 
TP/EP case only experts_kwargs = { "use_nvfp4_w4a4": True, "use_dp": moe_parallel_config.dp_size > 1, } - # if not moe_parallel_config.dp_size > 1 and moe_parallel_config.use_ep: experts_kwargs["ep_rank"] = moe_parallel_config.ep_rank experts_kwargs["ep_size"] = moe_parallel_config.ep_size experts_kwargs["tp_rank"] = moe_parallel_config.tp_rank @@ -779,7 +774,8 @@ def select_experts_impl(self, moe_parallel_config): # only prepare_finalize is not None call select_gemm_impl # so when native cutlass fp4, fused_expert is in fuse_moe.py fused_expert # when it's not called(TP case), we still have 2 kernels to use. - def select_gemm_impl(self, prepare_finalize, moe): + def select_gemm_impl(self, prepare_finalize, + moe) -> mk.FusedMoEPermuteExpertsUnpermute: assert moe is not None assert prepare_finalize is not None @@ -789,7 +785,7 @@ def select_gemm_impl(self, prepare_finalize, moe): if self.allow_flashinfer_cutlass: from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 FlashInferExperts) - logger.debug("FlashInferExperts %s", moe) + logger.debug_once("Using FlashInferExperts") experts = FlashInferExperts( use_nvfp4_w4a4=True, use_dp=moe.moe_parallel_config.dp_size > 1, @@ -800,11 +796,12 @@ def select_gemm_impl(self, prepare_finalize, moe): ) else: assert moe.dp_size > 1 - logger.debug("CutlassExpertsFp4 %s", moe) - # current doesn't support DP + logger.debug_once("Using CutlassExpertsFp4") + # Currently CutlassExpertsFp4 doesn't support DP raise ValueError( - "CutlassExpertsFp4 Doesn't support DP. " - "Use flashinfer CUTLASS FusedMoE backend instead.") + "CutlassExpertsFp4 doesn't support DP. " + "Use flashinfer CUTLASS FusedMoE(VLLM_USE_FLASHINFER_MOE)" + " backend instead.") return experts @@ -928,8 +925,30 @@ def swizzle_blockscale(self, scale: torch.tensor): if scale_ndim == 2 else swizzled_scale.reshape(B, M, K)) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # GEMM 1 + # The FlashInfer Cutlass fused MoE kernel expects the combined weights + # to be ordered as [w3, w1], unlike the standard [w1, w3] layout. + gemm1_weight = layer.w13_weight.data + gemm1_weight_scale = layer.w13_weight_scale.data + + if self.allow_flashinfer_cutlass: + dim = -2 + size = gemm1_weight.size(dim) + assert size % 2 == 0, f"Expected even size in dim {dim}, got {size}" + half = size // 2 + + # Reorder weight + w1, w3 = gemm1_weight.split(half, dim=dim) + gemm1_weight = torch.cat([w3, w1], dim=dim).contiguous() + + # Reorder scale + s1, s3 = gemm1_weight_scale.split(half, dim=dim) + gemm1_weight_scale = torch.cat([s3, s1], dim=dim).contiguous() + + layer.w13_weight = Parameter(gemm1_weight, requires_grad=False) + layer.w13_weight_scale = Parameter(gemm1_weight_scale, + requires_grad=False) + if not torch.allclose(layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]): logger.warning_once( @@ -960,9 +979,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w13_input_scale_quant = Parameter( (1 / w13_input_scale).to(torch.float32), requires_grad=False) - layer.w13_weight = Parameter(layer.w13_weight.data, - requires_grad=False) - # GEMM 2 layer.g2_alphas = Parameter( (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32), From 716621aece8ff460f41d7458224f6e387006d793 Mon Sep 17 00:00:00 2001 From: shuw Date: Thu, 17 Jul 2025 03:08:30 +0000 Subject: [PATCH 23/30] Address comments. 
Signed-off-by: shuw --- .../layers/fused_moe/cutlass_moe.py | 51 +++++++++-------- .../fused_moe/flashinfer_cutlass_moe.py | 4 +- .../layers/fused_moe/modular_kernel.py | 9 ++- .../layers/quantization/modelopt.py | 57 ++++++++++--------- 4 files changed, 65 insertions(+), 56 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index e0114dfb69bd..7984fb767ef3 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -538,9 +538,10 @@ def run_cutlass_moe_fp4( assert output.dtype == out_dtype if not apply_router_weight_on_input: - output.copy_((c3.view(m, num_topk, k) * - topk_weights.view(m, num_topk, 1).to(out_dtype)).sum(dim=1), - non_blocking=True) + output.copy_( + (c3.view(m, num_topk, k) * + topk_weights.view(m, num_topk, 1).to(out_dtype)).sum(dim=1), + non_blocking=True) else: output.copy_(c2.view(m, num_topk, k).sum(dim=1), non_blocking=True) return @@ -652,26 +653,26 @@ def apply( "ModelOptNvFp4FusedMoE.") run_cutlass_moe_fp4( - output, - hidden_states, - a1_gscale, - w1, - w1_scale, - g1_alphas, - a2_gscale, - w2, - w2_scale, - g2_alphas, - topk_weights, - topk_ids, - workspace13, - workspace2, - m, - n, - k, - e, - device, - apply_router_weight_on_input, + output=output, + a=hidden_states, + a1_gscale=a1_gscale, + w1_fp4=w1, + w1_blockscale=w1_scale, + w1_alphas=g1_alphas, + a2_gscale=a2_gscale, + w2_fp4=w2, + w2_blockscale=w2_scale, + w2_alphas=g2_alphas, + topk_weights=topk_weights, + topk_ids=topk_ids, + workspace13=workspace13, + workspace2=workspace2, + m=m, + n=n, + k=k, + e=e, + device=device, + apply_router_weight_on_input=apply_router_weight_on_input, ) @@ -719,9 +720,9 @@ def cutlass_moe_fp4( 'device': device, } - # NVFP4 requires two levels of quantization, which involves computing some + # NVFP4 requires two levels of quantization, which involves computing some # scaling factors dynamically. This makes it incompatible with the typical - # prepare -> MoE -> finalize pipeline. Move the quantization logic into the + # prepare -> MoE -> finalize pipeline. Move the quantization logic into the # MoE body. extra_prepare_args = { 'skip_quant': True, diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 3b236aeb8fae..f7afb59361a2 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -139,7 +139,7 @@ def apply( workspace13: Optional[torch.Tensor], workspace2: Optional[torch.Tensor], expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - apply_router_weight_on_input: Optional[bool], # Not used + apply_router_weight_on_input: Optional[bool], g1_alphas: torch.Tensor, g2_alphas: torch.Tensor, a1_gscale: torch.Tensor, @@ -156,6 +156,8 @@ def apply( "w1_scale and w2_scale must not " "be None for FlashInferExperts") + assert apply_router_weight_on_input is False + quant_scales = [ a1_gscale, w1_scale.view(torch.int32), diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index a0e444fd171e..3148a4d5e0b5 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -725,9 +725,12 @@ def forward( - apply_router_weight_on_input (bool): When true, the topk weights are applied directly on the inputs. 
This is only applicable when topk is 1. - - extra_expert_args (Optional[dict]): Extra keyword arguments to pass to fused_experts.apply. - - extra_prepare_args (Optional[dict]): Extra keyword arguments to pass to prepare. - - extra_finalize_args (Optional[dict]): Extra keyword arguments to pass to finalize. + - extra_expert_args (Optional[dict]): Extra keyword arguments to pass to + fused_experts.apply. + - extra_prepare_args (Optional[dict]): Extra keyword arguments to pass + to prepare. + - extra_finalize_args (Optional[dict]): Extra keyword arguments to pass + to finalize. Returns: - torch.Tensor: The output tensor after applying the MoE layer. diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 2daba8b232ef..ac841d223787 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -738,10 +738,8 @@ def __init__(self, quant_config: ModelOptNvFp4Config): raise ValueError("Current platform does not support NVFP4" " quantization. Please use Blackwell and" " above.") - from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - cutlass_moe_fp4) - self.fused_experts = cutlass_moe_fp4 + self.fused_experts = None # type: ignore def select_experts_impl( self, moe_parallel_config) -> mk.FusedMoEPermuteExpertsUnpermute: @@ -1063,15 +1061,40 @@ def apply( global_num_experts=global_num_experts, expert_map=expert_map) - a1_gscale = torch.min(layer.w13_input_scale_quant) - a2_gscale = torch.min(layer.w2_input_scale_quant) - if self.allow_flashinfer_cutlass: + if self.fused_experts is None: + # If no modular kernel is provided, use cutlass_moe_fp4 for TP case + # only (no EP). + from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + cutlass_moe_fp4) + out = cutlass_moe_fp4( + a=x, + w1_fp4=layer.w13_weight, + w2_fp4=layer.w2_weight, + w1_blockscale=layer.w13_blockscale_swizzled, + w2_blockscale=layer.w2_blockscale_swizzled, + g1_alphas=layer.g1_alphas, + g2_alphas=layer.g2_alphas, + a1_gscale=layer.w13_input_scale_quant, + a2_gscale=layer.w2_input_scale_quant, + topk_weights=topk_weights, + topk_ids=topk_ids, + m=x.shape[0], + n=layer.w2_weight.shape[2] * 2, + k=x.shape[1], + e=layer.w13_weight.shape[0], + device=x.device, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input) + else: # TP or DP case from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 is_valid_flashinfer_cutlass_fused_moe) assert is_valid_flashinfer_cutlass_fused_moe( x, layer.w13_weight, layer.w2_weight), ( "Flashinfer CUTLASS Fused MoE not applicable!") + + a1_gscale = torch.min(layer.w13_input_scale_quant) + a2_gscale = torch.min(layer.w2_input_scale_quant) extra_expert_args = { 'g1_alphas': layer.g1_alphas, 'g2_alphas': layer.g2_alphas, @@ -1103,29 +1126,9 @@ def apply( expert_map=expert_map, w1_scale=layer.w13_blockscale_swizzled, w2_scale=layer.w2_blockscale_swizzled, + apply_router_weight_on_input=apply_router_weight_on_input, extra_expert_args=extra_expert_args, extra_prepare_args=extra_prepare_args, extra_finalize_args=extra_finalize_args, ) - else: - # cutlass_moe_fp4, TP case only(no EP) - out = self.fused_experts( - a=x, - w1_fp4=layer.w13_weight, - w2_fp4=layer.w2_weight, - w1_blockscale=layer.w13_blockscale_swizzled, - w2_blockscale=layer.w2_blockscale_swizzled, - g1_alphas=layer.g1_alphas, - g2_alphas=layer.g2_alphas, - a1_gscale=layer.w13_input_scale_quant, - a2_gscale=layer.w2_input_scale_quant, - 
topk_weights=topk_weights, - topk_ids=topk_ids, - m=x.shape[0], - n=layer.w2_weight.shape[2] * 2, - k=x.shape[1], - e=layer.w13_weight.shape[0], - device=x.device, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input) return out From f31be5a8df76e9da2d9f53643c81b357469e3a55 Mon Sep 17 00:00:00 2001 From: shuw Date: Thu, 17 Jul 2025 05:25:50 +0000 Subject: [PATCH 24/30] Upd Signed-off-by: shuw --- .../layers/fused_moe/cutlass_moe.py | 65 +++++++++---------- .../fused_moe/flashinfer_cutlass_moe.py | 19 ++++-- .../flashinfer_cutlass_prepare_finalize.py | 42 ++++++------ vllm/model_executor/layers/fused_moe/layer.py | 13 ++-- .../layers/fused_moe/modular_kernel.py | 36 +++++----- vllm/model_executor/layers/fused_moe/utils.py | 2 +- .../compressed_tensors_moe.py | 10 +-- .../layers/quantization/modelopt.py | 7 +- 8 files changed, 101 insertions(+), 93 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 7984fb767ef3..d38674454295 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -616,41 +616,36 @@ def workspace_shapes( return (workspace1, workspace2, output, self.out_dtype if self.out_dtype is not None else a.dtype) - def apply( - self, - output: torch.Tensor, - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_ids: torch.Tensor, - activation: str, - global_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: torch.Tensor, - workspace13: Optional[torch.Tensor], - workspace2: Optional[torch.Tensor], - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - # start of extra_expert_args - topk_weights: torch.Tensor, - g1_alphas: torch.Tensor, - g2_alphas: torch.Tensor, - a1_gscale: torch.Tensor, - a2_gscale: torch.Tensor, - m: int, - n: int, - k: int, - e: int, - device: torch.device, - apply_router_weight_on_input, - ): - assert expert_map is None, ("Expert Parallelism / expert_map " - "is currently not supported for " - "ModelOptNvFp4FusedMoE.") + def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, + w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, activation: str, global_num_experts: int, + expert_map: Optional[torch.Tensor], w1_scale: torch.Tensor, + w2_scale: torch.Tensor, w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], + a2_scale: torch.Tensor, workspace13: Optional[torch.Tensor], + workspace2: Optional[torch.Tensor], + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict]): + assert 'g1_alphas' in extra_expert_args + assert 'g2_alphas' in extra_expert_args + assert 'a1_gscale' in extra_expert_args + assert 'a2_gscale' in extra_expert_args + assert 'm' in extra_expert_args + assert 'n' in extra_expert_args + assert 'k' in extra_expert_args + assert 'e' in extra_expert_args + assert 'device' in extra_expert_args + + g1_alphas = extra_expert_args['g1_alphas'] + g2_alphas = extra_expert_args['g2_alphas'] + a1_gscale = extra_expert_args['a1_gscale'] + a2_gscale = extra_expert_args['a2_gscale'] + m = extra_expert_args['m'] + n = extra_expert_args['n'] + k = extra_expert_args['k'] + e = extra_expert_args['e'] + device = 
extra_expert_args['device'] run_cutlass_moe_fp4( output=output, diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index f7afb59361a2..b0d1b06c4324 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -140,12 +140,19 @@ def apply( workspace2: Optional[torch.Tensor], expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: Optional[bool], - g1_alphas: torch.Tensor, - g2_alphas: torch.Tensor, - a1_gscale: torch.Tensor, - a2_gscale: torch.Tensor, - out_dtype: torch.dtype, + extra_expert_args: Optional[dict], ): + assert 'g1_alphas' in extra_expert_args + assert 'g2_alphas' in extra_expert_args + assert 'a1_gscale' in extra_expert_args + assert 'a2_gscale' in extra_expert_args + assert 'out_dtype' in extra_expert_args + g1_alphas: torch.Tensor = extra_expert_args['g1_alphas'] + g2_alphas: torch.Tensor = extra_expert_args['g2_alphas'] + a1_gscale: torch.Tensor = extra_expert_args['a1_gscale'] + a2_gscale: torch.Tensor = extra_expert_args['a2_gscale'] + out_dtype: torch.dtype = extra_expert_args['out_dtype'] + # Flashinfer CUTLASS kernel takes scalar global scales, # min because inv_scale. assert self.use_nvfp4_w4a4 is True, ("Only nvfp4 quantization is " @@ -156,7 +163,7 @@ def apply( "w1_scale and w2_scale must not " "be None for FlashInferExperts") - assert apply_router_weight_on_input is False + assert not apply_router_weight_on_input quant_scales = [ a1_gscale, diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py index 611dcbe4555d..eec0bac7de2e 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -69,18 +69,18 @@ def prepare( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - a1_gscale: torch.Tensor, - use_dp: Optional[bool] = True, - local_tokens: int = -1, + extra_prepare_args: Optional[dict] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: - if apply_router_weight_on_input: - topk = topk_ids.size(1) - # TODO: this only works for topK=1, will need to update for topK>1 - assert topk == 1, \ - "apply_router_weight_on_input is only implemented for topk=1" - a1.mul_(topk_weights.to(a1.dtype)) + assert not apply_router_weight_on_input + assert 'a1_gscale' in extra_prepare_args + assert 'use_dp' in extra_prepare_args + assert 'local_tokens' in extra_prepare_args + + a1_gscale = extra_prepare_args['a1_gscale'] + use_dp = extra_prepare_args['use_dp'] + local_tokens = extra_prepare_args['local_tokens'] a1q, a1q_scale = moe_kernel_quantize_input( a1, @@ -101,17 +101,19 @@ def prepare( return a1q, a1q_scale, None, topk_ids, topk_weights - def finalize( - self, - output: torch.Tensor, - fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce, - use_dp: bool = False, - local_tokens: int = -1, - ) -> None: + def finalize(self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: 
mk.TopKWeightAndReduce, + extra_finalize_args: Optional[dict] = None) -> None: + assert 'use_dp' in extra_finalize_args + assert 'local_tokens' in extra_finalize_args + use_dp = extra_finalize_args['use_dp'] + local_tokens = extra_finalize_args['local_tokens'] + if use_dp: fused_expert_output = get_dp_group().reduce_scatterv( fused_expert_output, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 19d7943ab8e8..8701d2583984 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -79,9 +79,6 @@ class FusedMoEMethodBase(QuantizeMethodBase): moe: FusedMoEConfig - def select_experts_impl(self, moe_parallel_config): - pass - @abstractmethod def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, @@ -214,6 +211,14 @@ def select_gemm_impl( f"{self.__class__.__name__} must select appropriate gemm " "implementation based on the prepare_finalize") + def maybe_swap_experts_impl( + self, + moe_parallel_config: FusedMoEParallelConfig, + ): + raise NotImplementedError( + f"{self.__class__.__name__} must select appropriate experts " + "implementation based on the moe_parallel_config") + @abstractmethod def apply( self, @@ -756,7 +761,7 @@ def __init__( quant_method = (UnquantizedFusedMoEMethod(moe) if quant_config is None else quant_config.get_quant_method(self, prefix)) - quant_method.select_experts_impl(self.moe_parallel_config) + quant_method.maybe_swap_experts_impl(self.moe_parallel_config) assert quant_method is not None assert isinstance(quant_method, FusedMoEMethodBase) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 3148a4d5e0b5..6af2b9645011 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -160,6 +160,7 @@ def prepare( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, + extra_prepare_args: Optional[dict] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExpertTokensMetadata], Optional[torch.Tensor], Optional[torch.Tensor]]: @@ -198,6 +199,7 @@ def finalize( topk_ids: torch.Tensor, apply_router_weight_on_input: bool, weight_and_reduce_impl: TopKWeightAndReduce, + extra_finalize_args: Optional[dict] = None, ) -> None: """ Perform any combine plus apply weights and perform a reduction on the @@ -375,6 +377,7 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: Optional[ExpertTokensMetadata], apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict], ): """ This function computes the intermediate result of a Mixture of Experts @@ -480,7 +483,7 @@ def _do_fused_experts( a2_scale: Optional[torch.Tensor], expert_tokens_meta: Optional[ExpertTokensMetadata], apply_router_weight_on_input: bool, - extra_expert_kwargs: Optional[dict] = None) -> torch.Tensor: + extra_expert_args: Optional[dict] = None) -> torch.Tensor: _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids) @@ -523,7 +526,7 @@ def _do_fused_experts( workspace2=workspace2, expert_tokens_meta=expert_tokens_meta, apply_router_weight_on_input=apply_router_weight_on_input, - **extra_expert_kwargs) + extra_expert_args=extra_expert_args) return fused_out @@ -547,7 +550,7 @@ def _maybe_chunk_fused_experts( a2_scale: Optional[torch.Tensor], expert_tokens_meta: Optional[ExpertTokensMetadata], 
apply_router_weight_on_input: bool, - extra_expert_kwargs: Optional[dict] = None, + extra_expert_args: Optional[dict] = None, ) -> torch.Tensor: _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids) @@ -576,7 +579,7 @@ def _maybe_chunk_fused_experts( a2_scale=a2_scale, expert_tokens_meta=expert_tokens_meta, apply_router_weight_on_input=apply_router_weight_on_input, - extra_expert_kwargs=extra_expert_kwargs) + extra_expert_args=extra_expert_args) # Chunking required case assert num_chunks > 1 @@ -630,9 +633,11 @@ def slice_expert_tokens_metadata( expert_num_tokens=c_expert_num_tokens, expert_num_tokens_cpu=c_expert_num_tokens_cpu) - m = extra_expert_kwargs.get('m') + m = None + if extra_expert_args is not None and 'm' in extra_expert_args: + m = extra_expert_args.get('m') - chunked_extra_expert_kwargs = extra_expert_kwargs + chunked_extra_expert_args = extra_expert_args for chunk_idx in range(num_chunks): c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids, c_topk_weights = ( slice_input_tensors(chunk_idx)) @@ -647,7 +652,7 @@ def slice_expert_tokens_metadata( e = min(s + CHUNK_SIZE, M) if m is not None: - chunked_extra_expert_kwargs['m'] = e - s + chunked_extra_expert_args['m'] = e - s self._do_fused_experts( fused_out=slice_output_tensor(chunk_idx), a1=a1, @@ -668,7 +673,7 @@ def slice_expert_tokens_metadata( a2_scale=c_a2_scale, expert_tokens_meta=c_expert_tokens_meta, apply_router_weight_on_input=apply_router_weight_on_input, - extra_expert_kwargs=chunked_extra_expert_kwargs) + extra_expert_args=chunked_extra_expert_args) return fused_out @@ -743,8 +748,6 @@ def forward( if global_num_experts == -1: global_num_experts = local_num_experts - extra_prepare_kwargs = extra_prepare_args or {} - (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, _expert_topk_weights) = self.prepare_finalize.prepare( a1, @@ -756,7 +759,7 @@ def forward( expert_map, apply_router_weight_on_input, self.fused_experts.quant_config, - **extra_prepare_kwargs, + extra_prepare_args, ) # Maybe prepare gathered topk_ids and topk_weights from other EP ranks. 
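As a usage reference for these per-call dictionaries, the ModelOptNvFp4FusedMoE.apply() hunk earlier in this patch invokes the modular kernel roughly as follows (simplified excerpt; use_dp and the local token count are placeholders for the values computed there):

    out = self.fused_experts(
        x, layer.w13_weight, layer.w2_weight,
        topk_weights=topk_weights, topk_ids=topk_ids,
        inplace=False, activation=activation,
        global_num_experts=global_num_experts, expert_map=expert_map,
        w1_scale=layer.w13_blockscale_swizzled,
        w2_scale=layer.w2_blockscale_swizzled,
        apply_router_weight_on_input=apply_router_weight_on_input,
        # consumed by FlashInferExperts.apply()
        extra_expert_args={'g1_alphas': layer.g1_alphas,
                           'g2_alphas': layer.g2_alphas,
                           'a1_gscale': a1_gscale,
                           'a2_gscale': a2_gscale,
                           'out_dtype': x.dtype},
        # consumed by FlashInferCutlassMoEPrepareAndFinalize.prepare()
        extra_prepare_args={'a1_gscale': a1_gscale,
                            'use_dp': use_dp,
                            'local_tokens': x.shape[0]},
        # consumed by FlashInferCutlassMoEPrepareAndFinalize.finalize()
        extra_finalize_args={'use_dp': use_dp,
                             'local_tokens': x.shape[0]},
    )
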
@@ -765,12 +768,6 @@ def forward( _expert_topk_weights) fused_out = None - extra_expert_kwargs = extra_expert_args or {} - - if 'topk_weights' in extra_expert_kwargs and extra_expert_kwargs[ - 'topk_weights'] is None: - extra_expert_kwargs['topk_weights'] = topk_weights - assert extra_expert_kwargs['topk_weights'] is not None if a1q.numel() == 0: # This happens when none of the tokens from the all2all reach this @@ -800,13 +797,12 @@ def forward( a2_scale=a2_scale, expert_tokens_meta=expert_tokens_meta, apply_router_weight_on_input=apply_router_weight_on_input, - extra_expert_kwargs=extra_expert_kwargs) + extra_expert_args=extra_expert_args) - extra_finalize_kwargs = extra_finalize_args or {} self.prepare_finalize.finalize( output, fused_out, topk_weights, topk_ids, apply_router_weight_on_input, self.fused_experts.finalize_weight_and_reduce_impl(), - **extra_finalize_kwargs) + extra_finalize_args) return output diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index 61841b9faa7b..aca733b16d1c 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -103,7 +103,7 @@ def _fp4_quantize( A: torch.Tensor, A_scale: Optional[torch.Tensor], is_sf_swizzled_layout: bool, -) -> tuple[torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: return fp4_quantize(A, A_scale, is_sf_swizzled_layout=is_sf_swizzled_layout) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index baf4fec3cc68..21147a223671 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -339,19 +339,19 @@ def apply( return cutlass_moe_fp4( a=x, w1_fp4=layer.w13_weight, - w1_blockscale=layer.w13_blockscale_swizzled, - w1_alphas=layer.g1_alphas, w2_fp4=layer.w2_weight, + w1_blockscale=layer.w13_blockscale_swizzled, w2_blockscale=layer.w2_blockscale_swizzled, - w2_alphas=layer.g2_alphas, + g1_alphas=layer.g1_alphas, + g2_alphas=layer.g2_alphas, + a1_gscale=layer.w13_input_scale_quant, + a2_gscale=layer.w2_input_scale_quant, topk_weights=topk_weights, topk_ids=topk_ids, m=x.shape[0], n=layer.w2_weight.shape[2] * 2, k=x.shape[1], e=layer.w13_weight.shape[0], - a1_gscale=layer.w13_input_scale_quant, - a2_gscale=layer.w2_input_scale_quant, device=x.device, apply_router_weight_on_input=apply_router_weight_on_input).to( x.dtype) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index ac841d223787..ba2b9ca157aa 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -13,6 +13,7 @@ cutlass_scaled_mm_supports_fp4, scaled_fp4_quant) from vllm.distributed import get_ep_group from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 FlashInferCutlassMoEPrepareAndFinalize) from vllm.model_executor.layers.fused_moe.layer import ( @@ -741,8 +742,10 @@ def __init__(self, quant_config: ModelOptNvFp4Config): self.fused_experts = None # type: ignore - def select_experts_impl( - self, moe_parallel_config) -> mk.FusedMoEPermuteExpertsUnpermute: + def maybe_swap_experts_impl( + self, + 
moe_parallel_config: FusedMoEParallelConfig, + ): if not self.allow_flashinfer_cutlass: return From 6225670538d6da1b2c84c0d335b8cd2d79953af5 Mon Sep 17 00:00:00 2001 From: shuw Date: Thu, 17 Jul 2025 05:41:58 +0000 Subject: [PATCH 25/30] Upd Signed-off-by: shuw --- .../layers/fused_moe/cutlass_moe.py | 3 +-- .../layers/fused_moe/flashinfer_cutlass_moe.py | 2 +- .../flashinfer_cutlass_prepare_finalize.py | 11 ++++------- .../layers/quantization/modelopt.py | 15 ++++++++++----- 4 files changed, 16 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index d38674454295..274a522a9e94 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -625,8 +625,7 @@ def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, a2_scale: torch.Tensor, workspace13: Optional[torch.Tensor], workspace2: Optional[torch.Tensor], expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict]): + apply_router_weight_on_input: bool, extra_expert_args: dict): assert 'g1_alphas' in extra_expert_args assert 'g2_alphas' in extra_expert_args assert 'a1_gscale' in extra_expert_args diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index b0d1b06c4324..4fffbfa51a42 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -140,7 +140,7 @@ def apply( workspace2: Optional[torch.Tensor], expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: Optional[bool], - extra_expert_args: Optional[dict], + extra_expert_args: dict, ): assert 'g1_alphas' in extra_expert_args assert 'g2_alphas' in extra_expert_args diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py index eec0bac7de2e..10e63889591a 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -69,7 +69,7 @@ def prepare( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - extra_prepare_args: Optional[dict] = None, + extra_prepare_args: dict, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: @@ -101,14 +101,11 @@ def prepare( return a1q, a1q_scale, None, topk_ids, topk_weights - def finalize(self, - output: torch.Tensor, - fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, + def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, apply_router_weight_on_input: bool, weight_and_reduce_impl: mk.TopKWeightAndReduce, - extra_finalize_args: Optional[dict] = None) -> None: + extra_finalize_args: dict) -> None: assert 'use_dp' in extra_finalize_args assert 'local_tokens' in extra_finalize_args use_dp = extra_finalize_args['use_dp'] diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index ba2b9ca157aa..02100a460e48 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ 
b/vllm/model_executor/layers/quantization/modelopt.py @@ -752,14 +752,19 @@ def maybe_swap_experts_impl( logger.debug_once("FlashInferExperts") # default to TP/EP case only - experts_kwargs = { + experts_kwargs: dict[str, Any] = { "use_nvfp4_w4a4": True, "use_dp": moe_parallel_config.dp_size > 1, } - experts_kwargs["ep_rank"] = moe_parallel_config.ep_rank - experts_kwargs["ep_size"] = moe_parallel_config.ep_size - experts_kwargs["tp_rank"] = moe_parallel_config.tp_rank - experts_kwargs["tp_size"] = moe_parallel_config.tp_size + experts_kwargs: dict[str, Any] = { + "use_nvfp4_w4a4": True, + "use_dp": moe_parallel_config.dp_size > 1, + "ep_rank": moe_parallel_config.ep_rank, + "ep_size": moe_parallel_config.ep_size, + "tp_rank": moe_parallel_config.tp_rank, + "tp_size": moe_parallel_config.tp_size, + } + from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 FlashInferExperts) experts = FlashInferExperts(**experts_kwargs) From a08b47f7d971fa02bd9f04835117906c103a3274 Mon Sep 17 00:00:00 2001 From: shuw Date: Thu, 17 Jul 2025 16:15:40 +0000 Subject: [PATCH 26/30] fix interface changes Signed-off-by: shuw --- .../layers/fused_moe/batched_deep_gemm_moe.py | 36 ++++------ .../layers/fused_moe/cutlass_moe.py | 37 ++++------- .../fused_moe/deepep_ht_prepare_finalize.py | 19 +++--- .../fused_moe/deepep_ll_prepare_finalize.py | 19 +++--- .../fused_moe/flashinfer_cutlass_moe.py | 23 ++++--- .../flashinfer_cutlass_prepare_finalize.py | 26 +++----- .../layers/fused_moe/fused_moe.py | 1 + vllm/model_executor/layers/fused_moe/layer.py | 4 +- .../layers/fused_moe/modular_kernel.py | 65 +++++++------------ .../layers/fused_moe/pplx_prepare_finalize.py | 30 ++++----- .../layers/fused_moe/prepare_finalize.py | 26 ++++---- vllm/model_executor/layers/fused_moe/utils.py | 16 ++++- .../layers/quantization/modelopt.py | 4 -- 13 files changed, 129 insertions(+), 177 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 0b3943292152..b77ec15e67bc 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from typing import Any, Optional import torch @@ -254,28 +254,18 @@ def workspace_shapes( output = (num_experts, max_num_tokens * num_dispatchers, K) return (workspace13, workspace2, output, a.dtype) - def apply( - self, - output: torch.Tensor, - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str, - global_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - workspace13: torch.Tensor, - workspace2: torch.Tensor, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - apply_router_weight_on_input: bool, - ): + def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, + w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, activation: str, global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], 
+ w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict[str, Any]]): assert expert_tokens_meta is not None expert_num_tokens = expert_tokens_meta.expert_num_tokens diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 274a522a9e94..c979d8626548 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ CUTLASS based Fused MoE kernels.""" -from typing import Callable, Optional +from typing import Any, Callable, Optional import torch @@ -15,7 +15,8 @@ TopKWeightAndReduceDelegate) from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm, _fp8_quantize, - _resize_cache) + _resize_cache, + extract_required_args) from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -301,7 +302,8 @@ def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - apply_router_weight_on_input: bool): + apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict[str, Any]]): assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE" assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE" @@ -625,27 +627,14 @@ def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, a2_scale: torch.Tensor, workspace13: Optional[torch.Tensor], workspace2: Optional[torch.Tensor], expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - apply_router_weight_on_input: bool, extra_expert_args: dict): - assert 'g1_alphas' in extra_expert_args - assert 'g2_alphas' in extra_expert_args - assert 'a1_gscale' in extra_expert_args - assert 'a2_gscale' in extra_expert_args - assert 'm' in extra_expert_args - assert 'n' in extra_expert_args - assert 'k' in extra_expert_args - assert 'e' in extra_expert_args - assert 'device' in extra_expert_args - - g1_alphas = extra_expert_args['g1_alphas'] - g2_alphas = extra_expert_args['g2_alphas'] - a1_gscale = extra_expert_args['a1_gscale'] - a2_gscale = extra_expert_args['a2_gscale'] - m = extra_expert_args['m'] - n = extra_expert_args['n'] - k = extra_expert_args['k'] - e = extra_expert_args['e'] - device = extra_expert_args['device'] - + apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict[str, Any]]): + required_keys = [ + "g1_alphas", "g2_alphas", "a1_gscale", "a2_gscale", "m", "n", "k", + "e", "device" + ] + (g1_alphas, g2_alphas, a1_gscale, a2_gscale, m, n, k, e, + device) = extract_required_args(extra_expert_args, required_keys) run_cutlass_moe_fp4( output=output, a=hidden_states, diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py index e10927c4dce5..7016ff34c3a8 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from typing import Any, Optional 
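The apply()/prepare() hunks in this commit replace the per-key asserts with an extract_required_args helper imported from fused_moe/utils.py. Its body is not shown in this excerpt; a plausible shape, offered only as an assumption about what such a helper does, is:

    from typing import Any, Optional

    def extract_required_args(extra_args: Optional[dict[str, Any]],
                              required_keys: list[str]) -> tuple[Any, ...]:
        # Hypothetical sketch: check that every required key was supplied and
        # return the values in order so callers can tuple-unpack them.
        assert extra_args is not None, "extra args dict must be provided"
        missing = [k for k in required_keys if k not in extra_args]
        assert not missing, f"missing extra args: {missing}"
        return tuple(extra_args[k] for k in required_keys)
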
import deep_ep import torch @@ -127,16 +127,12 @@ def _do_dispatch(self, tokens: torch.Tensor, expert_topk_weights) def prepare( - self, - a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_experts: int, - expert_map: Optional[torch.Tensor], - apply_router_weight_on_input: bool, + self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, + topk_ids: torch.Tensor, num_experts: int, + expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, + extra_prepare_args: Optional[dict[str, Any]] ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], Optional[torch.Tensor]]: @@ -191,7 +187,8 @@ def prepare( def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce) -> None: + weight_and_reduce_impl: mk.TopKWeightAndReduce, + extra_finalize_args: Optional[dict[str, Any]]) -> None: assert self.handle is not None diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index b04f01975849..57871ca250ae 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Union +from typing import Any, Optional, Union import deep_ep import torch @@ -111,16 +111,12 @@ def _do_quant( return x, x_scales def prepare( - self, - a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_experts: int, - expert_map: Optional[torch.Tensor], - apply_router_weight_on_input: bool, + self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, + topk_ids: torch.Tensor, num_experts: int, + expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, + extra_prepare_args: Optional[dict[str, Any]] ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], Optional[torch.Tensor]]: @@ -169,7 +165,8 @@ def prepare( def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce) -> None: + weight_and_reduce_impl: mk.TopKWeightAndReduce, + extra_finalize_args: Optional[dict[str, Any]]) -> None: assert isinstance( weight_and_reduce_impl, TopKWeightAndReduceDelegate ), ("Weight application and reduction happens in the combine kernel.") diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 4fffbfa51a42..ffb915a44c63 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional 
+from typing import Any, Optional import torch @@ -9,6 +9,7 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate) +from vllm.model_executor.layers.fused_moe.utils import extract_required_args from vllm.utils.flashinfer import (flashinfer_cutlass_fused_moe, has_flashinfer_cutlass_fused_moe) @@ -140,18 +141,16 @@ def apply( workspace2: Optional[torch.Tensor], expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: Optional[bool], - extra_expert_args: dict, + extra_expert_args: Optional[dict[str, Any]], ): - assert 'g1_alphas' in extra_expert_args - assert 'g2_alphas' in extra_expert_args - assert 'a1_gscale' in extra_expert_args - assert 'a2_gscale' in extra_expert_args - assert 'out_dtype' in extra_expert_args - g1_alphas: torch.Tensor = extra_expert_args['g1_alphas'] - g2_alphas: torch.Tensor = extra_expert_args['g2_alphas'] - a1_gscale: torch.Tensor = extra_expert_args['a1_gscale'] - a2_gscale: torch.Tensor = extra_expert_args['a2_gscale'] - out_dtype: torch.dtype = extra_expert_args['out_dtype'] + assert extra_expert_args is not None, \ + "extra_expert_args must be provided" + required_keys = [ + 'g1_alphas', 'g2_alphas', 'a1_gscale', 'a2_gscale', 'out_dtype' + ] + + g1_alphas, g2_alphas, a1_gscale, a2_gscale, out_dtype = ( + extract_required_args(extra_expert_args, required_keys)) # Flashinfer CUTLASS kernel takes scalar global scales, # min because inv_scale. diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py index 10e63889591a..49819504c8ec 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from typing import Any, Optional import torch @@ -10,7 +10,7 @@ from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.utils import ( - moe_kernel_quantize_input) + extract_required_args, moe_kernel_quantize_input) from vllm.utils.flashinfer import fp4_swizzle_blockscale @@ -69,18 +69,14 @@ def prepare( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - extra_prepare_args: dict, + extra_prepare_args: Optional[dict[str, Any]] ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: assert not apply_router_weight_on_input - assert 'a1_gscale' in extra_prepare_args - assert 'use_dp' in extra_prepare_args - assert 'local_tokens' in extra_prepare_args - a1_gscale = extra_prepare_args['a1_gscale'] - use_dp = extra_prepare_args['use_dp'] - local_tokens = extra_prepare_args['local_tokens'] + (a1_gscale, use_dp, local_tokens) = extract_required_args( + extra_prepare_args, ['a1_gscale', 'use_dp', 'local_tokens']) a1q, a1q_scale = moe_kernel_quantize_input( a1, @@ -88,8 +84,7 @@ def prepare( quant_config.quant_dtype, self.per_channel_quant, self.block_shape, - is_fp4_scale_swizzled= - not use_dp, # Needs swizzling after communication + is_fp4_scale_swizzled=not use_dp, # Swizzling after communication ) if use_dp: topk_weights, topk_ids, 
a1q, a1q_scale = \ @@ -105,12 +100,11 @@ def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, apply_router_weight_on_input: bool, weight_and_reduce_impl: mk.TopKWeightAndReduce, - extra_finalize_args: dict) -> None: - assert 'use_dp' in extra_finalize_args - assert 'local_tokens' in extra_finalize_args - use_dp = extra_finalize_args['use_dp'] - local_tokens = extra_finalize_args['local_tokens'] + extra_finalize_args: Optional[dict[str, Any]]) -> None: + (use_dp, + local_tokens) = extract_required_args(extra_finalize_args, + ['use_dp', 'local_tokens']) if use_dp: fused_expert_output = get_dp_group().reduce_scatterv( fused_expert_output, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index f0bffc7dae27..f10f4dcdb035 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1645,6 +1645,7 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict[str, Any]], ): # Check constraints. if self.use_int4_w4a16: diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 8701d2583984..4ccb943b8f40 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -761,8 +761,6 @@ def __init__( quant_method = (UnquantizedFusedMoEMethod(moe) if quant_config is None else quant_config.get_quant_method(self, prefix)) - quant_method.maybe_swap_experts_impl(self.moe_parallel_config) - assert quant_method is not None assert isinstance(quant_method, FusedMoEMethodBase) self.quant_method = quant_method @@ -797,6 +795,8 @@ def __init__( moe_quant_params["intermediate_size_full"] = intermediate_size self.quant_method.create_weights(layer=self, **moe_quant_params) + if isinstance(self.quant_method, FusedMoEMethodBase): + self.quant_method.maybe_swap_experts_impl(self.moe_parallel_config) # Chunked all2all staging tensor self.batched_hidden_states: Optional[torch.Tensor] = None diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 6af2b9645011..077bbbce760e 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from enum import Enum from math import prod -from typing import Optional, final +from typing import Any, Optional, final import torch @@ -150,17 +150,12 @@ class FusedMoEPrepareAndFinalize(ABC): @abstractmethod def prepare( - self, - a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_experts: int, - expert_map: Optional[torch.Tensor], - apply_router_weight_on_input: bool, + self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, + topk_ids: torch.Tensor, num_experts: int, + expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - extra_prepare_args: Optional[dict] = None, + extra_prepare_args: Optional[dict[str, Any]] ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExpertTokensMetadata], Optional[torch.Tensor], Optional[torch.Tensor]]: @@ -191,16 +186,11 @@ def prepare( raise 
NotImplementedError @abstractmethod - def finalize( - self, - output: torch.Tensor, - fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: TopKWeightAndReduce, - extra_finalize_args: Optional[dict] = None, - ) -> None: + def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: TopKWeightAndReduce, + extra_finalize_args: Optional[dict[str, Any]]) -> None: """ Perform any combine plus apply weights and perform a reduction on the fused experts output. @@ -377,7 +367,7 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: Optional[ExpertTokensMetadata], apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict], + extra_expert_args: Optional[dict[str, Any]], ): """ This function computes the intermediate result of a Mixture of Experts @@ -463,27 +453,18 @@ def __init__( f"{fused_experts.activation_formats[0]}") def _do_fused_experts( - self, - fused_out: Optional[torch.Tensor], - a1: torch.Tensor, - a1q: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str, - global_num_experts: int, - local_num_experts: int, + self, fused_out: Optional[torch.Tensor], a1: torch.Tensor, + a1q: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + activation: str, global_num_experts: int, local_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], expert_tokens_meta: Optional[ExpertTokensMetadata], apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict] = None) -> torch.Tensor: + extra_expert_args: Optional[dict[str, Any]]) -> torch.Tensor: _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids) @@ -550,7 +531,7 @@ def _maybe_chunk_fused_experts( a2_scale: Optional[torch.Tensor], expert_tokens_meta: Optional[ExpertTokensMetadata], apply_router_weight_on_input: bool, - extra_expert_args: Optional[dict] = None, + extra_expert_args: Optional[dict[str, Any]], ) -> torch.Tensor: _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids) @@ -637,7 +618,11 @@ def slice_expert_tokens_metadata( if extra_expert_args is not None and 'm' in extra_expert_args: m = extra_expert_args.get('m') - chunked_extra_expert_args = extra_expert_args + if extra_expert_args is not None: + chunked_extra_expert_args = extra_expert_args + else: + chunked_extra_expert_args = {} + for chunk_idx in range(num_chunks): c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids, c_topk_weights = ( slice_input_tensors(chunk_idx)) diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 5a23a9f1ab09..46931f2dd7c7 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from typing import Any, Optional import 
pplx_kernels as pplx import torch @@ -89,16 +89,12 @@ def num_dispatchers(self) -> int: return self.num_dispatchers_ def prepare( - self, - a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_experts: int, - expert_map: Optional[torch.Tensor], - apply_router_weight_on_input: bool, + self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, + topk_ids: torch.Tensor, num_experts: int, + expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, + extra_prepare_args: Optional[dict[str, Any]] ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], Optional[torch.Tensor]]: @@ -217,15 +213,11 @@ def prepare( return expert_x, expert_x_scale, expert_tokens_meta, None, None - def finalize( - self, - output: torch.Tensor, - fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce, - ) -> None: + def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + extra_finalize_args: Optional[dict[str, Any]]) -> None: assert isinstance( weight_and_reduce_impl, TopKWeightAndReduceDelegate ), ("Weight application and reduction happens in the combine kernel.") diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index 178174b5a15f..696c7cdba9a7 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from typing import Any, Optional import torch @@ -38,7 +38,7 @@ def prepare( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - skip_quant: Optional[bool] = False, + extra_prepare_args: Optional[dict[str, Any]], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], Optional[torch.Tensor]]: @@ -50,7 +50,9 @@ def prepare( "apply_router_weight_on_input is only implemented for topk=1" a1.mul_(topk_weights.to(a1.dtype)) - if skip_quant: + if (extra_prepare_args is not None + and extra_prepare_args.get("skip_quant", True)): + # Skip quantization if explicitly requested return a1, None, None, None, None a1q, a1q_scale = moe_kernel_quantize_input( @@ -59,17 +61,13 @@ def prepare( return a1q, a1q_scale, None, None, None - def finalize( - self, - output: torch.Tensor, - fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce, - skip_weight_reduce: Optional[bool] = False, - ) -> None: - if skip_weight_reduce: + def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + extra_finalize_args: Optional[dict[str, Any]]) -> None: + if (extra_finalize_args is not None + and 
extra_finalize_args.get("skip_weight_reduce", True)): assert output.shape == fused_expert_output.shape output.copy_(fused_expert_output) else: diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index aca733b16d1c..6ac99c94fddc 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from math import prod -from typing import Optional, Union +from typing import Any, Optional, Union import torch @@ -252,3 +252,17 @@ def _validate_scale_shape( assert block_shape is not None expected = (a.shape[0], cdiv(a.shape[1], block_shape[1])) assert a_scale.shape == expected, f"{a_scale.shape} == {expected}" + + +def extract_required_args( + extra_args: dict[str, Any], + required_keys: list[str], +) -> tuple[Any, ...]: + if extra_args is None: + raise ValueError("`extra_args` must be provided.") + + missing_keys = [k for k in required_keys if k not in extra_args] + if missing_keys: + raise ValueError(f"Missing keys in `extra_args`: {missing_keys}") + + return tuple(extra_args[k] for k in required_keys) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 02100a460e48..3807899fc3e5 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -752,10 +752,6 @@ def maybe_swap_experts_impl( logger.debug_once("FlashInferExperts") # default to TP/EP case only - experts_kwargs: dict[str, Any] = { - "use_nvfp4_w4a4": True, - "use_dp": moe_parallel_config.dp_size > 1, - } experts_kwargs: dict[str, Any] = { "use_nvfp4_w4a4": True, "use_dp": moe_parallel_config.dp_size > 1, From 24ce9aaaaf177a2e21e57c0e3a28e8c38c3df8fa Mon Sep 17 00:00:00 2001 From: shuw Date: Thu, 17 Jul 2025 16:41:56 +0000 Subject: [PATCH 27/30] Upd Signed-off-by: shuw --- .../batched_triton_or_deep_gemm_moe.py | 7 ++-- .../layers/fused_moe/cutlass_moe.py | 1 + .../layers/fused_moe/deep_gemm_moe.py | 3 +- .../fused_moe/flashinfer_cutlass_moe.py | 12 +++++- .../layers/fused_moe/fused_batched_moe.py | 36 ++++++++---------- vllm/model_executor/layers/fused_moe/layer.py | 4 +- .../layers/fused_moe/triton_deep_gemm_moe.py | 37 +++++++------------ vllm/model_executor/layers/fused_moe/utils.py | 2 +- 8 files changed, 48 insertions(+), 54 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py index 1a63b3237343..fc30e84e6656 100644 --- a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from typing import Any, Optional import torch @@ -142,7 +142,8 @@ def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - apply_router_weight_on_input: bool): + apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict[str, Any]]): experts = (self.batched_deep_gemm_experts if self.allow_deep_gemm else self.batched_triton_experts) assert experts is not None @@ -150,4 +151,4 @@ def 
apply(self, output: torch.Tensor, hidden_states: torch.Tensor, activation, global_num_experts, expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, workspace13, workspace2, expert_tokens_meta, - apply_router_weight_on_input) + apply_router_weight_on_input, extra_expert_args) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 7bb686b2ace3..5d213cd5cce1 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -619,6 +619,7 @@ def workspace_shapes( topk: int, global_num_experts: int, local_num_experts: int, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: workspace1: tuple[int, ...] = () workspace2: tuple[int, ...] = () diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index bb462938a392..dee6ad138a81 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools -from typing import Optional +from typing import Any, Optional import torch @@ -136,6 +136,7 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict[str, Any]], ): assert self.block_shape is not None assert a1q_scale is not None diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index ffb915a44c63..1753c4f6e238 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -88,8 +88,16 @@ def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: return TopKWeightAndReduceDelegate() def workspace_shapes( - self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int, - topk: int, global_num_experts: int, local_num_experts: int + self, + a: torch.Tensor, + aq: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: # We use global_num_experts due to how moe_align_block_size handles # expert_maps. 
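Note on the extract_required_args call sites in the FlashInfer experts and prepare/finalize changes above: the helper (added in vllm/model_executor/layers/fused_moe/utils.py earlier in this series) replaces the per-key assert blocks. A minimal, self-contained usage sketch; the tensor values and shapes below are placeholders, only the required keys and the error behavior come from the patch:

import torch

from vllm.model_executor.layers.fused_moe.utils import extract_required_args

# Placeholder per-expert scales; shapes and values are illustrative only.
num_experts = 8
extra_expert_args = {
    "g1_alphas": torch.ones(num_experts, dtype=torch.float32),
    "g2_alphas": torch.ones(num_experts, dtype=torch.float32),
    "a1_gscale": torch.tensor(1.0, dtype=torch.float32),
    "a2_gscale": torch.tensor(1.0, dtype=torch.float32),
    "out_dtype": torch.bfloat16,
}

g1_alphas, g2_alphas, a1_gscale, a2_gscale, out_dtype = extract_required_args(
    extra_expert_args,
    ["g1_alphas", "g2_alphas", "a1_gscale", "a2_gscale", "out_dtype"])

assert out_dtype is torch.bfloat16

# Passing None, or a dict missing any required key, raises a ValueError that
# names the missing keys instead of surfacing a bare AssertionError deep in
# the kernel wrapper.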
diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index ab8a281b3901..9a5c85e120cc 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Fused batched MoE kernel.""" -from typing import Optional +from typing import Any, Optional import torch @@ -496,16 +496,12 @@ def num_dispatchers(self) -> int: return self.num_dispatchers_ def prepare( - self, - a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_experts: int, - expert_map: Optional[torch.Tensor], - apply_router_weight_on_input: bool, + self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, + topk_ids: torch.Tensor, num_experts: int, + expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, + extra_prepare_args: Optional[dict[str, Any]] ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], Optional[torch.Tensor]]: @@ -594,15 +590,11 @@ def prepare( return b_a1, b_a1_scale, expert_tokens_meta, None, None - def finalize( - self, - output: torch.Tensor, - fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce, - ) -> None: + def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + extra_finalize_args: Optional[dict[str, Any]]) -> None: if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate): weight_and_reduce_impl = TopKWeightAndReduceNaiveBatched(self.rank) weight_and_reduce_impl.apply( @@ -706,7 +698,8 @@ def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - apply_router_weight_on_input: bool): + apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict[str, Any]]): assert hidden_states.dim() == 3 assert expert_tokens_meta is not None expert_num_tokens = expert_tokens_meta.expert_num_tokens @@ -911,7 +904,8 @@ def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - apply_router_weight_on_input: bool): + apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict[str, Any]]): # Check constraints. 
if self.use_int4_w4a16: assert hidden_states.size(-1) // 2 == w1.size(2), ( diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 4ccb943b8f40..c4c5d62122f4 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -215,9 +215,7 @@ def maybe_swap_experts_impl( self, moe_parallel_config: FusedMoEParallelConfig, ): - raise NotImplementedError( - f"{self.__class__.__name__} must select appropriate experts " - "implementation based on the moe_parallel_config") + pass @abstractmethod def apply( diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index 51b95c9aa922..1b31368c79cd 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from typing import Any, Optional import torch @@ -119,28 +119,18 @@ def workspace_shapes( local_num_experts, expert_tokens_meta) - def apply( - self, - output: torch.Tensor, - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str, - global_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - workspace13: torch.Tensor, - workspace2: torch.Tensor, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - apply_router_weight_on_input: bool, - ): + def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, + w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, activation: str, global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict[str, Any]]): use_deep_gemm = (self.allow_deep_gemm and (_valid_deep_gemm(hidden_states, w1, w2) or is_blackwell_deep_gemm_used())) @@ -168,4 +158,5 @@ def apply( workspace2, expert_tokens_meta, apply_router_weight_on_input, + extra_expert_args, ) diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index 6ac99c94fddc..966471b5c59b 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -255,7 +255,7 @@ def _validate_scale_shape( def extract_required_args( - extra_args: dict[str, Any], + extra_args: Optional[dict[str, Any]], required_keys: list[str], ) -> tuple[Any, ...]: if extra_args is None: From e5d14ea6b116ba577195a643eeece6d128f98e8a Mon Sep 17 00:00:00 2001 From: mgoin Date: Thu, 17 Jul 2025 13:53:50 -0400 Subject: [PATCH 28/30] Fix run_cutlass_moe_fp4 for Llama 4 Signed-off-by: mgoin --- vllm/model_executor/layers/fused_moe/cutlass_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py 
b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 5d213cd5cce1..484e36a2cf0b 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -562,7 +562,7 @@ def run_cutlass_moe_fp4( topk_weights.view(m, num_topk, 1).to(out_dtype)).sum(dim=1), non_blocking=True) else: - output.copy_(c2.view(m, num_topk, k).sum(dim=1), non_blocking=True) + output.copy_(c3.view(m, num_topk, k).sum(dim=1), non_blocking=True) return From 4b3ee2e76353570a157fd3bb009ffc485b6c479d Mon Sep 17 00:00:00 2001 From: mgoin Date: Thu, 17 Jul 2025 17:12:49 -0400 Subject: [PATCH 29/30] Ensure lazy imports for flashinfer Signed-off-by: mgoin --- vllm/utils/flashinfer.py | 108 ++++++++++++++++++--------------------- 1 file changed, 49 insertions(+), 59 deletions(-) diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index e0523a20e21c..dbd2dc393046 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -9,6 +9,7 @@ import contextlib import functools import importlib +import importlib.util from typing import Any, Callable, NoReturn from vllm.logger import init_logger @@ -19,11 +20,9 @@ @functools.cache def has_flashinfer() -> bool: """Return ``True`` if FlashInfer is available.""" - try: - import flashinfer # noqa: F401 - return True - except ImportError: - return False + # Use find_spec to check if the module exists without importing it + # This avoids potential CUDA initialization side effects + return importlib.util.find_spec("flashinfer") is not None def _missing(*_: Any, **__: Any) -> NoReturn: @@ -42,69 +41,60 @@ def _get_submodule(module_name: str) -> Any | None: return None -# Initialize FlashInfer components -if not has_flashinfer(): - _cutlass_fused_moe_impl: Callable[..., Any] | None = None - _fp4_quantize_impl: Callable[..., Any] | None = None - _fp4_swizzle_blockscale_impl: Callable[..., Any] | None = None - _autotune_impl: Callable[..., Any] | None = None -else: - # Import main flashinfer module - _fi = importlib.import_module("flashinfer") # type: ignore - - # Import fused_moe submodule - _fused_moe_mod = _get_submodule("flashinfer.fused_moe") - _cutlass_fused_moe_impl = getattr(_fused_moe_mod, "cutlass_fused_moe", - None) if _fused_moe_mod else None - - # Import fp4_quant functions - _fp4_quantize_impl = getattr(_fi, "fp4_quantize", None) if _fi else None - _fp4_swizzle_blockscale_impl = getattr(_fi, "fp4_swizzle_blockscale", - None) if _fi else None +# General lazy import wrapper +def _lazy_import_wrapper(module_name: str, + attr_name: str, + fallback_fn: Callable[..., Any] = _missing): + """Create a lazy import wrapper for a specific function.""" - # Import autotuner submodule - _autotuner_mod = _get_submodule("flashinfer.autotuner") - _autotune_impl = getattr(_autotuner_mod, "autotune", - None) if _autotuner_mod else None - - -@functools.cache -def has_flashinfer_cutlass_fused_moe() -> bool: - """Return ``True`` if FlashInfer CUTLASS fused MoE is available.""" - return all([ - _cutlass_fused_moe_impl, - _fp4_quantize_impl, - _fp4_swizzle_blockscale_impl, - ]) + @functools.cache + def _get_impl(): + if not has_flashinfer(): + return None + mod = _get_submodule(module_name) + return getattr(mod, attr_name, None) if mod else None + def wrapper(*args, **kwargs): + impl = _get_impl() + if impl is None: + return fallback_fn(*args, **kwargs) + return impl(*args, **kwargs) -def flashinfer_cutlass_fused_moe(*args, **kwargs): - """FlashInfer CUTLASS fused MoE kernel.""" - if _cutlass_fused_moe_impl is None: - return 
_missing(*args, **kwargs) - return _cutlass_fused_moe_impl(*args, **kwargs) + return wrapper -def fp4_quantize(*args, **kwargs): - """FlashInfer FP4 quantization.""" - if _fp4_quantize_impl is None: - return _missing(*args, **kwargs) - return _fp4_quantize_impl(*args, **kwargs) +# Create lazy wrappers for each function +flashinfer_cutlass_fused_moe = _lazy_import_wrapper("flashinfer.fused_moe", + "cutlass_fused_moe") +fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize") +fp4_swizzle_blockscale = _lazy_import_wrapper("flashinfer", + "fp4_swizzle_blockscale") +# Special case for autotune since it returns a context manager +autotune = _lazy_import_wrapper( + "flashinfer.autotuner", + "autotune", + fallback_fn=lambda *args, **kwargs: contextlib.nullcontext()) -def fp4_swizzle_blockscale(*args, **kwargs): - """FlashInfer FP4 swizzle blockscale.""" - if _fp4_swizzle_blockscale_impl is None: - return _missing(*args, **kwargs) - return _fp4_swizzle_blockscale_impl(*args, **kwargs) +@functools.cache +def has_flashinfer_cutlass_fused_moe() -> bool: + """Return ``True`` if FlashInfer CUTLASS fused MoE is available.""" + if not has_flashinfer(): + return False -def autotune(*args, **kwargs): - """FlashInfer autotuner.""" - if _autotune_impl is None: - # return a null context since autotune is a context manager - return contextlib.nullcontext() - return _autotune_impl(*args, **kwargs) + # Check if all required functions are available + required_functions = [ + ("flashinfer.fused_moe", "cutlass_fused_moe"), + ("flashinfer", "fp4_quantize"), + ("flashinfer", "fp4_swizzle_blockscale"), + ] + + for module_name, attr_name in required_functions: + mod = _get_submodule(module_name) + if not mod or not hasattr(mod, attr_name): + return False + return True __all__ = [ From cc0e87d81926d35470b57cadd89775f7eb15aa04 Mon Sep 17 00:00:00 2001 From: mgoin Date: Thu, 17 Jul 2025 19:27:38 -0400 Subject: [PATCH 30/30] Just use has_flashinfer in fused_moe/layer.py Signed-off-by: mgoin --- vllm/model_executor/layers/fused_moe/layer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index c4c5d62122f4..93adfb81cb1b 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -34,7 +34,7 @@ from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx -from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe +from vllm.utils.flashinfer import has_flashinfer if current_platform.is_cuda_alike(): from .fused_batched_moe import BatchedTritonExperts @@ -46,7 +46,7 @@ from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SHAPE, DeepEPLLPrepareAndFinalize) - if has_flashinfer_cutlass_fused_moe(): + if has_flashinfer(): from .flashinfer_cutlass_prepare_finalize import ( FlashInferCutlassMoEPrepareAndFinalize) else:
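A standalone sketch of the lazy-import pattern that vllm/utils/flashinfer.py now uses: a cached find_spec probe for the optional package, a cached attribute resolver per wrapped function, and a fallback callable when the backend is absent. The package and attribute names here (somepkg, fused_op) are placeholders, not real modules; only the structure mirrors the patch, and the sketch runs with the standard library alone:

import contextlib
import functools
import importlib
import importlib.util
from typing import Any, Callable, Optional


@functools.cache
def _has_package(name: str) -> bool:
    # find_spec consults the import machinery without importing the module,
    # so probing an optional dependency cannot trigger import-time side
    # effects such as CUDA context initialization.
    return importlib.util.find_spec(name) is not None


def _missing(*_: Any, **__: Any):
    raise RuntimeError("optional backend 'somepkg' is not installed")


def _lazy_import_wrapper(
        module_name: str,
        attr_name: str,
        fallback_fn: Callable[..., Any] = _missing) -> Callable[..., Any]:
    """Resolve module_name.attr_name on first call, not at import time."""

    @functools.cache
    def _get_impl() -> Optional[Callable[..., Any]]:
        if not _has_package(module_name.split(".")[0]):
            return None
        try:
            mod = importlib.import_module(module_name)
        except ImportError:
            return None
        return getattr(mod, attr_name, None)

    def wrapper(*args, **kwargs):
        impl = _get_impl()
        if impl is None:
            return fallback_fn(*args, **kwargs)
        return impl(*args, **kwargs)

    return wrapper


# fused_op raises through _missing when somepkg is absent; autotune degrades
# to a no-op context manager, matching the nullcontext fallback in the patch.
fused_op = _lazy_import_wrapper("somepkg.kernels", "fused_op")
maybe_autotune = _lazy_import_wrapper(
    "somepkg.autotuner",
    "autotune",
    fallback_fn=lambda *args, **kwargs: contextlib.nullcontext())

if __name__ == "__main__":
    with maybe_autotune():
        try:
            fused_op()
        except RuntimeError as exc:
            print(exc)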