diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 90134683180a..1ba50fec930a 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1290,7 +1290,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
         # Skip this check if we are running on a non-GPU platform,
         # or if the device capability is not available
         # (e.g. in a Ray actor without GPUs).
-        from vllm.platforms import current_platform
+        from vllm.platforms import CpuArchEnum, current_platform
         if (current_platform.is_cuda()
                 and current_platform.get_device_capability()
                 and current_platform.get_device_capability().major < 8):
@@ -1434,7 +1434,8 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
         # Non-[CUDA, TPU] may be supported on V1, but off by default for now.
         v0_hardware = not any(
             (current_platform.is_cuda(), current_platform.is_tpu(),
-             current_platform.is_cpu()))
+             (current_platform.is_cpu()
+              and current_platform.get_cpu_architecture() == CpuArchEnum.X86)))
         if v0_hardware and _warn_or_fallback(  # noqa: SIM103
                 current_platform.device_name):
             return False
diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py
index 265959d626e0..71c964fbfbb5 100644
--- a/vllm/platforms/cpu.py
+++ b/vllm/platforms/cpu.py
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import os
+import platform
 import sys
 from importlib.util import find_spec
 from typing import TYPE_CHECKING, Optional
@@ -22,6 +23,15 @@
     VllmConfig = None
 
 
+def get_max_threads(pid=0):
+    if hasattr(os, 'sched_getaffinity'):
+        return len(os.sched_getaffinity(pid))
+    elif platform.system() == 'Darwin':
+        return os.cpu_count()
+    else:
+        raise NotImplementedError("Unsupported OS")
+
+
 class CpuPlatform(Platform):
     _enum = PlatformEnum.CPU
     device_name: str = "cpu"
@@ -190,7 +200,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
 
         # Note: to avoid the error 'nthreads cannot be larger than environment
         # variable "NUMEXPR_MAX_THREADS" (64)'.
-        os.environ["NUMEXPR_MAX_THREADS"] = str(len(os.sched_getaffinity(0)))
+        os.environ["NUMEXPR_MAX_THREADS"] = str(get_max_threads())
 
         # Set default threads num for OpenMP parallel
         os.environ["OMP_NUM_THREADS"] = str(torch.get_num_threads())
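
For reference, a minimal standalone sketch of the gating change in the arg_utils.py hunks above. The platform probes are stubbed out as plain parameters, and the CpuArchEnum below is a hypothetical two-member stand-in for vllm.platforms.CpuArchEnum; only the boolean expression mirrors the patch.

from enum import Enum

class CpuArchEnum(Enum):
    # Hypothetical stand-in for vllm.platforms.CpuArchEnum (subset).
    X86 = "x86"
    ARM = "arm"

def v0_hardware(is_cuda: bool, is_tpu: bool, is_cpu: bool,
                cpu_arch: CpuArchEnum) -> bool:
    # After this patch, V1 stays on by default for CUDA, TPU, and x86 CPUs;
    # any other hardware (e.g. an ARM CPU) takes the warn-or-fallback path.
    return not any((is_cuda, is_tpu,
                    is_cpu and cpu_arch == CpuArchEnum.X86))

assert v0_hardware(False, False, True, CpuArchEnum.X86) is False  # x86 CPU: V1 stays on
assert v0_hardware(False, False, True, CpuArchEnum.ARM) is True   # ARM CPU: gated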
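
The cpu.py change matters because os.sched_getaffinity is Linux-only; on macOS the old len(os.sched_getaffinity(0)) call raises AttributeError. Below is a self-contained replica of the new helper (renamed _max_threads here so it cannot be confused with the patched module) that can be run on either OS to check the fallback; the `or 1` guard against os.cpu_count() returning None is an addition of this sketch, not part of the patch.

import os
import platform

def _max_threads(pid: int = 0) -> int:
    # Linux: sched_getaffinity reflects the process's CPU affinity mask,
    # which can be narrower than the machine's total core count.
    if hasattr(os, 'sched_getaffinity'):
        return len(os.sched_getaffinity(pid))
    # macOS has no affinity API; fall back to the logical core count.
    if platform.system() == 'Darwin':
        return os.cpu_count() or 1  # sketch-only guard; the patch returns os.cpu_count() as-is
    raise NotImplementedError("Unsupported OS")

if __name__ == "__main__":
    print(_max_threads())  # e.g. 8 on an 8-core machine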