20 changes: 7 additions & 13 deletions tests/kernels/moe/test_deepep_deepgemm_moe.py
@@ -6,7 +6,6 @@
"""

import dataclasses
import importlib
from typing import Optional

import pytest
@@ -21,38 +20,33 @@
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.platforms import current_platform
from vllm.utils import has_deep_ep, has_deep_gemm

from .utils import ProcessGroupInfo, parallel_launch

has_deep_ep = importlib.util.find_spec("deep_ep") is not None

try:
import deep_gemm
has_deep_gemm = True
except ImportError:
has_deep_gemm = False

if has_deep_ep:
if has_deep_ep():
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
DeepEPLLPrepareAndFinalize)

from .deepep_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a

if has_deep_gemm:
if has_deep_gemm():
import deep_gemm

from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
DeepGemmExperts)

requires_deep_ep = pytest.mark.skipif(
not has_deep_ep,
not has_deep_ep(),
reason="Requires deep_ep kernels",
)

requires_deep_gemm = pytest.mark.skipif(
not has_deep_gemm,
not has_deep_gemm(),
reason="Requires deep_gemm kernels",
)
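
For context, a minimal sketch of how these markers gate a test; the test name and body below are illustrative, not part of this change:

```python
import torch

# Illustrative only: applying the markers defined above. The skipif condition
# (e.g. `not has_deep_ep()`) is evaluated once when the marker is built at
# import time, and pytest skips the test with the given reason if it is True.
@requires_deep_ep
@requires_deep_gemm
def test_deepep_deepgemm_smoke():
    # Placeholder body; the real tests in this file spawn workers via
    # parallel_launch (imported above) and compare fused-MoE outputs
    # against a reference implementation.
    assert torch.cuda.device_count() >= 1
```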

8 changes: 3 additions & 5 deletions tests/kernels/moe/test_deepep_moe.py
@@ -4,7 +4,6 @@
"""

import dataclasses
import importlib
from typing import Optional, Union

import pytest
@@ -22,12 +21,11 @@
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.platforms import current_platform
from vllm.utils import has_deep_ep

from .utils import ProcessGroupInfo, parallel_launch

has_deep_ep = importlib.util.find_spec("deep_ep") is not None

if has_deep_ep:
if has_deep_ep():
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
@@ -36,7 +34,7 @@
from .deepep_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a

requires_deep_ep = pytest.mark.skipif(
not has_deep_ep,
not has_deep_ep(),
reason="Requires deep_ep kernels",
)

10 changes: 5 additions & 5 deletions vllm/distributed/device_communicators/all2all.py
@@ -1,13 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib.util
from typing import TYPE_CHECKING, Any

import torch
import torch.distributed as dist

from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.utils import has_deep_ep, has_pplx

from .base_device_communicator import All2AllManagerBase, Cache

@@ -80,8 +80,8 @@ class PPLXAll2AllManager(All2AllManagerBase):
"""

def __init__(self, cpu_group):
has_pplx = importlib.util.find_spec("pplx_kernels") is not None
assert has_pplx, "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels." # noqa
assert has_pplx(
), "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels." # noqa
super().__init__(cpu_group)

if self.internode:
@@ -133,8 +133,8 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
"""

def __init__(self, cpu_group):
has_deepep = importlib.util.find_spec("deep_ep") is not None
assert has_deepep, "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels." # noqa
assert has_deep_ep(
), "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels." # noqa
super().__init__(cpu_group)
self.handle_cache = Cache()

@@ -1,5 +1,4 @@
# SPDX-License-Identifier: Apache-2.0
import importlib.util
from typing import Optional

import torch
@@ -11,8 +10,6 @@

logger = init_logger(__name__)

has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None


@triton.jit
def _silu_mul_fp8_quant_deep_gemm(
12 changes: 5 additions & 7 deletions vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import importlib.util
from typing import Optional

import torch
@@ -12,14 +11,13 @@
_moe_permute)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, per_token_group_quant_fp8)
from vllm.utils import round_up
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.utils import has_deep_gemm, round_up

logger = init_logger(__name__)

has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None


@functools.cache
def deep_gemm_block_shape() -> list[int]:
@@ -41,7 +39,7 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor,
gemm kernel. All of M, N, K and the quantization block_shape must be
aligned by `dg.get_m_alignment_for_contiguous_layout()`.
"""
if not has_deep_gemm:
if not has_deep_gemm():
logger.debug("DeepGemm disabled: deep_gemm not available.")
return False
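
A rough sketch of the alignment rule described in the docstring above; `align` stands in for `dg.get_m_alignment_for_contiguous_layout()`, the concrete value 128 is an assumption, and the real check also validates the quantization block_shape:

```python
def _shapes_aligned(m: int, n: int, k: int, align: int = 128) -> bool:
    # Hypothetical helper mirroring the documented constraint: every problem
    # dimension must be a multiple of DeepGEMM's contiguous-layout alignment.
    return m % align == 0 and n % align == 0 and k % align == 0
```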

10 changes: 3 additions & 7 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import importlib
from abc import abstractmethod
from collections.abc import Iterable
from dataclasses import dataclass
@@ -32,20 +31,17 @@
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from vllm.utils import direct_register_custom_op

has_pplx = importlib.util.find_spec("pplx_kernels") is not None
has_deepep = importlib.util.find_spec("deep_ep") is not None
from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx

if current_platform.is_cuda_alike():
from .fused_batched_moe import BatchedTritonExperts
from .fused_moe import TritonExperts, fused_experts
from .modular_kernel import (FusedMoEModularKernel,
FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize)
if has_pplx:
if has_pplx():
from .pplx_prepare_finalize import PplxPrepareAndFinalize
if has_deepep:
if has_deep_ep():
from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SIZE,
DeepEPLLPrepareAndFinalize)
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/fused_moe/utils.py
@@ -104,4 +104,4 @@ def find_free_port():
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.bind(('', 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return s.getsockname()[1]
return s.getsockname()[1]
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import enum
import importlib
from enum import Enum
from typing import Callable, Optional

@@ -29,13 +28,12 @@
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types

has_pplx = importlib.util.find_spec("pplx_kernels") is not None
from vllm.utils import has_pplx

if current_platform.is_cuda_alike():
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize)
if has_pplx:
if has_pplx():
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize)

@@ -577,7 +575,7 @@ def select_gemm_impl(self, prepare_finalize, moe):
use_batched_format=True,
)

if has_pplx and isinstance(
if has_pplx() and isinstance(
prepare_finalize,
(BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
# no expert_map support in this case
6 changes: 2 additions & 4 deletions vllm/model_executor/layers/quantization/deepgemm.py
@@ -1,15 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
import importlib.util
import logging

import torch

from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils import direct_register_custom_op
from vllm.utils import direct_register_custom_op, has_deep_gemm

has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
if has_deep_gemm:
if has_deep_gemm():
import deep_gemm

logger = logging.getLogger(__name__)
6 changes: 2 additions & 4 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import functools
import importlib.util
from typing import Any, Callable, Optional, Union

import torch
@@ -38,13 +37,12 @@
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils import has_deep_gemm

ACTIVATION_SCHEMES = ["static", "dynamic"]

logger = init_logger(__name__)

has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None


def _is_col_major(x: torch.Tensor) -> bool:
assert x.dim() == 3
@@ -451,7 +449,7 @@ def __init__(self, quant_config: Fp8Config):
# Check for DeepGemm support.
self.allow_deep_gemm = False
if envs.VLLM_USE_DEEP_GEMM:
if not has_deep_gemm:
if not has_deep_gemm():
logger.warning_once("Failed to import DeepGemm kernels.")
elif not self.block_quant:
logger.warning_once("Model is not block quantized. Not using "
6 changes: 2 additions & 4 deletions vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -3,7 +3,6 @@

# Adapted from https://github.com/sgl-project/sglang/pull/2575
import functools
import importlib.util
import json
import os
from typing import Any, Callable, Optional, Union
@@ -19,10 +18,9 @@
CUTLASS_BLOCK_FP8_SUPPORTED)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv, direct_register_custom_op
from vllm.utils import cdiv, direct_register_custom_op, has_deep_gemm

logger = init_logger(__name__)
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None


def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
@@ -109,7 +107,7 @@ def should_use_deepgemm(output_dtype: torch.dtype, weight: torch.Tensor):
"""

return (current_platform.is_cuda()
and current_platform.is_device_capability(90) and has_deep_gemm
and current_platform.is_device_capability(90) and has_deep_gemm()
and envs.VLLM_USE_DEEP_GEMM and output_dtype == torch.bfloat16
and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)

28 changes: 28 additions & 0 deletions vllm/utils.py
@@ -2929,3 +2929,31 @@ def is_torch_equal_or_newer(target: str) -> bool:
def _is_torch_equal_or_newer(torch_version: str, target: str) -> bool:
torch_version = version.parse(torch_version)
return torch_version >= version.parse(target)


@cache
def _has_module(module_name: str) -> bool:
"""Return True if *module_name* can be found in the current environment.

The result is cached so that subsequent queries for the same module incur
no additional overhead.
"""
return importlib.util.find_spec(module_name) is not None


def has_pplx() -> bool:
"""Whether the optional `pplx_kernels` package is available."""

return _has_module("pplx_kernels")


def has_deep_ep() -> bool:
"""Whether the optional `deep_ep` package is available."""

return _has_module("deep_ep")


def has_deep_gemm() -> bool:
"""Whether the optional `deep_gemm` package is available."""

return _has_module("deep_gemm")
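
With these predicates, the call sites updated throughout this diff share one cached probe per package instead of each module running `importlib.util.find_spec` at import time. A minimal usage sketch, not part of the diff itself:

```python
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx

# The first call probes importlib.util.find_spec; later calls return the
# memoized result, so these checks are cheap to repeat.
if has_deep_gemm():
    import deep_gemm  # only imported when the package is actually installed

# Probes can be combined freely when gating optional expert-parallel paths.
a2a_kernels_available = has_pplx() or has_deep_ep()
print(a2a_kernels_available)
```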