Skip to content

Commit 4ebc910

Browse files
authored
[Kernel] Centralize platform kernel import in current_platform.import_kernels (#26286)
Signed-off-by: NickLucche <nlucches@redhat.com>
1 parent e1ba235 commit 4ebc910

File tree

4 files changed

+13
-15
lines changed

4 files changed

+13
-15
lines changed

vllm/_custom_ops.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -12,8 +12,7 @@
1212

1313
logger = init_logger(__name__)
1414

15-
current_platform.import_core_kernels()
16-
supports_moe_ops = current_platform.try_import_moe_kernels()
15+
current_platform.import_kernels()
1716

1817
if TYPE_CHECKING:
1918

@@ -1921,7 +1920,7 @@ def moe_wna16_marlin_gemm(
19211920
)
19221921

19231922

1924-
if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
1923+
if hasattr(torch.ops, "_moe_C") and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
19251924

19261925
@register_fake("_moe_C::marlin_gemm_moe")
19271926
def marlin_gemm_moe_fake(

vllm/platforms/interface.py

Lines changed: 1 addition & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -170,22 +170,15 @@ def device_id_to_physical_device_id(cls, device_id: int):
170170
return device_id
171171

172172
@classmethod
173-
def import_core_kernels(cls) -> None:
173+
def import_kernels(cls) -> None:
174174
"""Import any platform-specific C kernels."""
175175
try:
176176
import vllm._C # noqa: F401
177177
except ImportError as e:
178178
logger.warning("Failed to import from vllm._C: %r", e)
179-
180-
@classmethod
181-
def try_import_moe_kernels(cls) -> bool:
182-
"""Import any platform-specific MoE kernels."""
183179
with contextlib.suppress(ImportError):
184180
import vllm._moe_C # noqa: F401
185181

186-
return True
187-
return False
188-
189182
@classmethod
190183
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
191184
from vllm.attention.backends.registry import _Backend

vllm/platforms/tpu.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

4+
import contextlib
45
from typing import TYPE_CHECKING, Optional, Union, cast
56

67
import torch
@@ -45,8 +46,10 @@ class TpuPlatform(Platform):
4546
additional_env_vars: list[str] = ["TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"]
4647

4748
@classmethod
48-
def import_core_kernels(cls) -> None:
49-
pass
49+
def import_kernels(cls) -> None:
50+
# Do not import vllm._C
51+
with contextlib.suppress(ImportError):
52+
import vllm._moe_C # noqa: F401
5053

5154
@classmethod
5255
def get_attn_backend_cls(

vllm/platforms/xpu.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

4+
import contextlib
45
import os
56
from typing import TYPE_CHECKING, Optional
67

@@ -35,8 +36,10 @@ class XPUPlatform(Platform):
3536
device_control_env_var: str = "ZE_AFFINITY_MASK"
3637

3738
@classmethod
38-
def import_core_kernels(cls) -> None:
39-
pass
39+
def import_kernels(cls) -> None:
40+
# Do not import vllm._C
41+
with contextlib.suppress(ImportError):
42+
import vllm._moe_C # noqa: F401
4043

4144
@classmethod
4245
def get_attn_backend_cls(

0 commit comments

Comments
 (0)