
[Serve] On kuberay, vLLM-0.7.2 reports "No CUDA GPUs are available" while vllm-0.6.6.post1 works fine when deploy rayservice #51154

Open
pteric opened this issue Mar 7, 2025 · 6 comments
Labels
  • bug: Something that is supposed to be working; but isn't
  • serve: Ray Serve Related Issue
  • triage: Needs triage (eg: priority, bug/not-bug, and owning component)

Comments

@pteric

pteric commented Mar 7, 2025

What happened + What you expected to happen

Description

When deploying the Qwen2.5-0.5B model with KubeRay and vLLM 0.7.2, the deployment fails with "RuntimeError: No CUDA GPUs are available". The same deployment works fine with vLLM 0.6.6.post1 under identical environment conditions.

Environment Information

  • Container Image: rayproject/ray:2.43.0-py39-cu124
  • vLLM:
    • Failed version: 0.7.2
    • Working version: 0.6.6.post1
  • Model: Qwen2.5-0.5B

Steps to Reproduce

Deploy a RayService with KubeRay using the image rayproject/ray:2.43.0-py39-cu124. The RayService manifest is:

apiVersion: ray.io/v1
kind: RayService
metadata:
  name: qwen2005-0005b-vllm07
spec:
  serveConfigV2: |
    applications:
    - name: llm
      route_prefix: /
      import_path: latest-serve:model
      deployments:
      - name: VLLMDeployment
        num_replicas: 1
        ray_actor_options:
          num_cpus: 4
      runtime_env:
        working_dir: "https://xxx/vllm_script.zip"
        pip: 
          - "vllm==0.7.2"
        env_vars:
          MODEL_ID: "Qwen/Qwen2.5-0.5B"
          TENSOR_PARALLELISM: "1"
          PIPELINE_PARALLELISM: "1"
  rayClusterConfig:
    headGroupSpec:
      rayStartParams:
        dashboard-host: '0.0.0.0'
      template:
        spec:
          containers:
          - name: ray-head
            image: rayproject/ray:2.43.0-py39-cu124
            imagePullPolicy: IfNotPresent
            resources:
              limits:
                cpu: "8"
                memory: "16Gi"
              requests:
                cpu: "2"
                memory: "4Gi"
            ports:
            - containerPort: 6379
              name: gcs-server
            - containerPort: 8265
              name: dashboard
            - containerPort: 10001
              name: client
            - containerPort: 8000
              name: serve
            env:
            - name: HUGGING_FACE_HUB_TOKEN
              valueFrom:
                secretKeyRef:
                  name: hf-secret
                  key: hf_api_token
    workerGroupSpecs:
    - replicas: 1
      minReplicas: 1
      maxReplicas: 2
      groupName: gpu-group
      rayStartParams: {}
      template:
        spec:
          containers:
          - name: llm
            image: rayproject/ray:2.43.0-py39-cu124
            imagePullPolicy: IfNotPresent
            env:
            - name: HUGGING_FACE_HUB_TOKEN
              valueFrom:
                secretKeyRef:
                  name: hf-secret
                  key: hf_api_token
            resources:
              limits:
                cpu: "8"
                memory: "16Gi"
                nvidia.com/gpu: "1"
              requests:
                cpu: "4"
                memory: "8Gi"
                nvidia.com/gpu: "1"
          tolerations:
            - key: "nvidia.com/gpu"
              operator: "Exists"
              effect: "NoSchedule"

The latest-serve.py in https://xxx/vllm_script.zip is taken from https://github.com/ray-project/ray/blob/master/doc/source/serve/doc_code/vllm_openai_example.py.

The exception traceback:

ray::ServeReplica:llm:VLLMDeployment.initialize_and_get_metadata() (pid=1886, ip=10.58.29.125, actor_id=c3a99f2865a8a727c40545aa01000000, repr=<ray.serve._private.replica.ServeReplica:llm:VLLMDeployment object at 0x7f9966d21550>)
  File "/home/ray/anaconda3/lib/python3.9/concurrent/futures/_base.py", line 446, in result
    return self.__get_result()
  File "/home/ray/anaconda3/lib/python3.9/concurrent/futures/_base.py", line 391, in __get_result
    raise self._exception
  File "/home/ray/anaconda3/lib/python3.9/site-packages/ray/serve/_private/replica.py", line 965, in initialize_and_get_metadata
    await self._replica_impl.initialize(deployment_config)
  File "/home/ray/anaconda3/lib/python3.9/site-packages/ray/serve/_private/replica.py", line 694, in initialize
    raise RuntimeError(traceback.format_exc()) from None
RuntimeError: Traceback (most recent call last):
  File "/home/ray/anaconda3/lib/python3.9/site-packages/ray/serve/_private/replica.py", line 671, in initialize
    self._user_callable_asgi_app = await asyncio.wrap_future(
  File "/home/ray/anaconda3/lib/python3.9/site-packages/ray/serve/_private/replica.py", line 1363, in initialize_callable
    await self._call_func_or_gen(
  File "/home/ray/anaconda3/lib/python3.9/site-packages/ray/serve/_private/replica.py", line 1324, in _call_func_or_gen
    result = callable(*args, **kwargs)
  File "/home/ray/anaconda3/lib/python3.9/site-packages/ray/serve/api.py", line 221, in __init__
    cls.__init__(self, *args, **kwargs)
  File "/tmp/ray/session_2025-03-06_22-57-24_631998_1/runtime_resources/working_dir_files/https_aistudio-ant-mpc_oss-cn-zhangjiakou_aliyuncs_com_pengtuo_kuberay_vllm_script/latest-serve.py", line 57, in __init__
    self.engine = AsyncLLMEngine.from_engine_args(engine_args)
  File "/tmp/ray/session_2025-03-06_22-57-24_631998_1/runtime_resources/pip/a425849cda8f3a2d8bc88454de4cdc8455c376c1/virtualenv/lib/python3.9/site-packages/vllm/engine/async_llm_engine.py", line 644, in from_engine_args
    engine = cls(
  File "/tmp/ray/session_2025-03-06_22-57-24_631998_1/runtime_resources/pip/a425849cda8f3a2d8bc88454de4cdc8455c376c1/virtualenv/lib/python3.9/site-packages/vllm/engine/async_llm_engine.py", line 594, in __init__
    self.engine = self._engine_class(*args, **kwargs)
  File "/tmp/ray/session_2025-03-06_22-57-24_631998_1/runtime_resources/pip/a425849cda8f3a2d8bc88454de4cdc8455c376c1/virtualenv/lib/python3.9/site-packages/vllm/engine/async_llm_engine.py", line 267, in __init__
    super().__init__(*args, **kwargs)
  File "/tmp/ray/session_2025-03-06_22-57-24_631998_1/runtime_resources/pip/a425849cda8f3a2d8bc88454de4cdc8455c376c1/virtualenv/lib/python3.9/site-packages/vllm/engine/llm_engine.py", line 273, in __init__
    self.model_executor = executor_class(vllm_config=vllm_config, )
  File "/tmp/ray/session_2025-03-06_22-57-24_631998_1/runtime_resources/pip/a425849cda8f3a2d8bc88454de4cdc8455c376c1/virtualenv/lib/python3.9/site-packages/vllm/executor/executor_base.py", line 51, in __init__
    self._init_executor()
  File "/tmp/ray/session_2025-03-06_22-57-24_631998_1/runtime_resources/pip/a425849cda8f3a2d8bc88454de4cdc8455c376c1/virtualenv/lib/python3.9/site-packages/vllm/executor/uniproc_executor.py", line 41, in _init_executor
    self.collective_rpc("init_device")
  File "/tmp/ray/session_2025-03-06_22-57-24_631998_1/runtime_resources/pip/a425849cda8f3a2d8bc88454de4cdc8455c376c1/virtualenv/lib/python3.9/site-packages/vllm/executor/uniproc_executor.py", line 51, in collective_rpc
    answer = run_method(self.driver_worker, method, args, kwargs)
  File "/tmp/ray/session_2025-03-06_22-57-24_631998_1/runtime_resources/pip/a425849cda8f3a2d8bc88454de4cdc8455c376c1/virtualenv/lib/python3.9/site-packages/vllm/utils.py", line 2220, in run_method
    return func(*args, **kwargs)
  File "/tmp/ray/session_2025-03-06_22-57-24_631998_1/runtime_resources/pip/a425849cda8f3a2d8bc88454de4cdc8455c376c1/virtualenv/lib/python3.9/site-packages/vllm/worker/worker.py", line 155, in init_device
    torch.cuda.set_device(self.device)
  File "/tmp/ray/session_2025-03-06_22-57-24_631998_1/runtime_resources/pip/a425849cda8f3a2d8bc88454de4cdc8455c376c1/virtualenv/lib/python3.9/site-packages/torch/cuda/__init__.py", line 478, in set_device
    torch._C._cuda_setDevice(device)
  File "/tmp/ray/session_2025-03-06_22-57-24_631998_1/runtime_resources/pip/a425849cda8f3a2d8bc88454de4cdc8455c376c1/virtualenv/lib/python3.9/site-packages/torch/cuda/__init__.py", line 319, in _lazy_init
    torch._C._cuda_init()
RuntimeError: No CUDA GPUs are available

Related issue

I've been searching for solutions and found two issues that match my symptoms, but the solutions provided in those issues don't work in my case:
vllm-project/vllm#6896
#50275

Versions / Dependencies

Ray image: rayproject/ray:2.43.0-py39-cu124
PyTorch version: 2.5.1+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35

Python version: 3.9.21 | packaged by conda-forge | (main, Dec  5 2024, 13:51:40)  [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-5.10.134-18.al8.x86_64-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.4.131
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A10
Nvidia driver version: 550.144.03
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.1.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-ml-py==12.570.86
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] pyzmq==26.2.1
[pip3] torch==2.5.1
[pip3] torchaudio==2.5.1
[pip3] torchvision==0.20.1
[pip3] transformers==4.49.0
[pip3] triton==3.1.0

NVIDIA_VISIBLE_DEVICES=0
NVIDIA_REQUIRE_CUDA=cuda>=12.4 brand=tesla,driver>=470,driver<471 brand=unknown,driver>=470,driver<471 brand=nvidia,driver>=470,driver<471 brand=nvidiartx,driver>=470,driver<471 brand=geforce,driver>=470,driver<471 brand=geforcertx,driver>=470,driver<471 brand=quadro,driver>=470,driver<471 brand=quadrortx,driver>=470,driver<471 brand=titan,driver>=470,driver<471 brand=titanrtx,driver>=470,driver<471 brand=tesla,driver>=525,driver<526 brand=unknown,driver>=525,driver<526 brand=nvidia,driver>=525,driver<526 brand=nvidiartx,driver>=525,driver<526 brand=geforce,driver>=525,driver<526 brand=geforcertx,driver>=525,driver<526 brand=quadro,driver>=525,driver<526 brand=quadrortx,driver>=525,driver<526 brand=titan,driver>=525,driver<526 brand=titanrtx,driver>=525,driver<526 brand=tesla,driver>=535,driver<536 brand=unknown,driver>=535,driver<536 brand=nvidia,driver>=535,driver<536 brand=nvidiartx,driver>=535,driver<536 brand=geforce,driver>=535,driver<536 brand=geforcertx,driver>=535,driver<536 brand=quadro,driver>=535,driver<536 brand=quadrortx,driver>=535,driver<536 brand=titan,driver>=535,driver<536 brand=titanrtx,driver>=535,driver<536
NCCL_VERSION=2.21.5-1
NVIDIA_DRIVER_CAPABILITIES=compute,utility
NVIDIA_PRODUCT_NAME=CUDA
CUDA_VERSION=12.4.1
LD_LIBRARY_PATH=/tmp/ray/session_2025-03-06_04-45-27_752822_1/runtime_resources/pip/a425849cda8f3a2d8bc88454de4cdc8455c376c1/virtualenv/lib/python3.9/site-packages/cv2/../../lib64:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
NCCL_CUMEM_ENABLE=0
TORCHINDUCTOR_COMPILE_THREADS=1
CUDA_MODULE_LOADING=LAZY

Reproduction script

The vLLM deployment code:

import os

from typing import Dict, Optional, List
import logging

from fastapi import FastAPI
from starlette.requests import Request
from starlette.responses import StreamingResponse, JSONResponse

from ray import serve

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ErrorResponse,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_models import (
    BaseModelPath,
    LoRAModulePath,
    PromptAdapterPath,
    OpenAIServingModels,
)

from vllm.utils import FlexibleArgumentParser
from vllm.entrypoints.logger import RequestLogger

logger = logging.getLogger("ray.serve")

app = FastAPI()


@serve.deployment(name="VLLMDeployment")
@serve.ingress(app)
class VLLMDeployment:
    def __init__(
        self,
        engine_args: AsyncEngineArgs,
        response_role: str,
        lora_modules: Optional[List[LoRAModulePath]] = None,
        prompt_adapters: Optional[List[PromptAdapterPath]] = None,
        request_logger: Optional[RequestLogger] = None,
        chat_template: Optional[str] = None,
    ):
        logger.info(f"Starting with engine args: {engine_args}")
        self.openai_serving_chat = None
        self.engine_args = engine_args
        self.response_role = response_role
        self.lora_modules = lora_modules
        self.prompt_adapters = prompt_adapters
        self.request_logger = request_logger
        self.chat_template = chat_template
        self.engine = AsyncLLMEngine.from_engine_args(engine_args)

    @app.post("/v1/chat/completions")
    async def create_chat_completion(self, request: ChatCompletionRequest, raw_request: Request):
        if not self.openai_serving_chat:
            model_config = await self.engine.get_model_config()
            models = OpenAIServingModels(
                self.engine,
                model_config,
                [
                    BaseModelPath(
                        name=self.engine_args.model, model_path=self.engine_args.model
                    )
                ],
                lora_modules=self.lora_modules,
                prompt_adapters=self.prompt_adapters,
            )
            self.openai_serving_chat = OpenAIServingChat(
                self.engine,
                model_config,
                models,
                self.response_role,
                request_logger=self.request_logger,
                chat_template=self.chat_template,
                chat_template_content_format="auto",
            )
        logger.info(f"Request: {request}")
        generator = await self.openai_serving_chat.create_chat_completion(
            request, raw_request
        )
        if isinstance(generator, ErrorResponse):
            return JSONResponse(
                content=generator.model_dump(), status_code=generator.code
            )
        if request.stream:
            return StreamingResponse(content=generator, media_type="text/event-stream")
        else:
            assert isinstance(generator, ChatCompletionResponse)
            return JSONResponse(content=generator.model_dump())


def parse_vllm_args(cli_args: Dict[str, str]):
    arg_parser = FlexibleArgumentParser(
        description="vLLM OpenAI-Compatible RESTful API server."
    )
    parser = make_arg_parser(arg_parser)
    arg_strings = []
    for key, value in cli_args.items():
        arg_strings.extend([f"--{key}", str(value)])
    logger.info(arg_strings)
    parsed_args = parser.parse_args(args=arg_strings)
    return parsed_args


# serve run latest-serve:build_app model="Qwen/Qwen2.5-0.5B" tensor-parallel-size=1 accelerator="GPU"
def build_app(cli_args: Dict[str, str]) -> serve.Application:
    logger.info("*" * 100)
    if "accelerator" in cli_args.keys():
        accelerator = cli_args.pop("accelerator")
    else:
        accelerator = "GPU"
    parsed_args = parse_vllm_args(cli_args)

    engine_args = AsyncEngineArgs.from_cli_args(parsed_args)
    engine_args.worker_use_ray = True

    tp = engine_args.tensor_parallel_size
    logger.info(f"Tensor parallelism = {tp}")
    pg_resources = []
    pg_resources.append({"CPU": 4})  # for the deployment replica
    for i in range(tp):
        pg_resources.append({"CPU": 2, accelerator: 1})  # for the vLLM actors

    return VLLMDeployment.options(
        placement_group_bundles=pg_resources, placement_group_strategy="SPREAD"
    ).bind(
        engine_args,
        parsed_args.response_role,
        parsed_args.lora_modules,
        parsed_args.prompt_adapters,
        cli_args.get("request_logger"),
        parsed_args.chat_template,
    )

model = build_app({
    "model": os.environ['MODEL_ID'],
    "port": "8080",
    "tensor-parallel-size": os.environ['TENSOR_PARALLELISM'],
    "pipeline-parallel-size": os.environ['PIPELINE_PARALLELISM'],
    "max-model-len": os.environ['MODEL_LEN'],
    "gpu-memory-utilization": os.environ['GPU_MEMORY_UTILIZATION'],
    "dtype": os.environ['DTYPE'],
    "kv-cache-dtype": os.environ['KV_CACHE_DTYPE']
})
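
For debugging, a check like the one below can be dropped into VLLMDeployment.__init__ before the engine is constructed, to log what the replica actor actually sees. This is a diagnostic sketch only and is not part of the original reproduction script.

# Debugging sketch (not part of the original repro): log the GPU visibility
# of the deployment replica actor before AsyncLLMEngine is constructed.
import os

import ray
import torch


def log_gpu_visibility(logger) -> None:
    # GPU IDs Ray has assigned to the current actor (empty for a CPU-only replica).
    logger.info(f"ray.get_gpu_ids() = {ray.get_gpu_ids()}")
    # An empty or missing value here means no GPU is visible to this process.
    logger.info(f"CUDA_VISIBLE_DEVICES = {os.environ.get('CUDA_VISIBLE_DEVICES')}")
    logger.info(f"torch.cuda.is_available() = {torch.cuda.is_available()}")
    logger.info(f"torch.cuda.device_count() = {torch.cuda.device_count()}")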

Issue Severity

High: It blocks me from completing my task.

pteric added the bug and triage labels on Mar 7, 2025
pteric changed the title from "[<Ray component: Core|RLlib|etc...>] vLLM-0.7.2 reports "No CUDA GPUs are available" while vllm-0.6.6.post1 works fine on kuberay." to "[Serve] On kuberay, vLLM-0.7.2 reports "No CUDA GPUs are available" while vllm-0.6.6.post1 works fine when deploy rayservice" on Mar 7, 2025
@huiyeruzhou

same issue

@huiyeruzhou

huiyeruzhou commented Mar 10, 2025

Issue Summary:

I think I have identified the issue in RayDistributedExecutor. Specifically, the AsyncLLMEngine ultimately dispatches world_size workers, one of which is a local worker instead of a remote Ray actor.

Therefore, in the Ray Serve framework, the deployment replica actor itself (which uses the first bundle in the placement group) will need GPU resources to create the local worker that will be initialized by AsyncLLMEngine.

In Ray's demo code, no GPU resources are allocated to the deployment replica actor, which leads to a CUDA error when a method of RayDistributedExecutor.driver_worker (the local worker) is called in RayDistributedExecutor._run_workers:

    pg_resources.append({"CPU": 4})  # for the deployment replica, CPU ONLY!

Current Problem:

However, simply adding GPU resources to the first bundle may not resolve the issue, because vLLM creates a dummy Ray actor to hold the resources used by the local worker. If GPU resources are already allocated to the deployment replica actor, this results in an extra resource request, and the placement group waits indefinitely.
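
For illustration only, the naive change described above would look roughly like this. It is a sketch of the attempted fix, not a working one, since the dummy placeholder actor then requests one more GPU than the placement group provides:

# Sketch of the naive change discussed above (not a working fix): give the
# replica's own bundle a GPU in addition to the per-worker GPU bundles.
tp = 1  # tensor_parallel_size
pg_resources = []
pg_resources.append({"CPU": 4, "GPU": 1})      # deployment replica now requests a GPU
for _ in range(tp):
    pg_resources.append({"CPU": 2, "GPU": 1})  # vLLM worker actors
# Problem described above: vLLM still creates a dummy placeholder actor for the
# local worker, so the total GPU demand exceeds what the placement group
# provides and scheduling waits forever.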

Request for Feedback:

I’m just starting to investigate this issue, so there may be inaccuracies in my understanding. Any comments, suggestions, or corrections are welcome! Let me know if you have insights or ideas for addressing this problem.

@pang-wu

pang-wu commented Mar 10, 2025

We ran into the same issue when trying to serve Qwen2.5 VL AWQ on KubeRay. Running vllm serve directly in the same pod does not have the problem.

@huiyeruzhou

Hi! There is a HACK that works for me (on vLLM 0.6.3; I'm not sure whether it still applies to 0.7.0+. Once I figure that out, I will open a PR).

Based on the understanding above, I ultimately found that we can add a branch to the code related to the dummy worker (the resource placeholder for the local worker).

The key point is that ray.get_runtime_context().get_actor_id() returns None (and throws a warning; ignore it :) if the code is not running inside an actor. We can use it to detect our environment: whether we are using AsyncLLMEngine directly, or inside an actor such as a deployment replica actor (the Ray Serve scenario).
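
A minimal standalone example of this detection (my sketch, not vLLM code):

# Standalone sketch: get_actor_id() distinguishes code running inside a Ray
# actor (e.g. a Serve deployment replica) from plain driver code.
import ray


@ray.remote
class Probe:
    def where(self):
        # Inside an actor this returns the actor's ID string.
        return ray.get_runtime_context().get_actor_id()


if __name__ == "__main__":
    ray.init()
    # On the driver (not inside an actor) this returns None and logs a warning.
    print("driver:", ray.get_runtime_context().get_actor_id())
    probe = Probe.remote()
    print("actor:", ray.get(probe.where.remote()))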

First, we skip one bundle in the worker-creation loop to avoid the extra resource allocation by adding the following code:

        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
            if not bundle.get("GPU", 0):
                continue
            #### BEGIN
            if ray.get_runtime_context().get_actor_id():
                # since we are in an actor, we should not create another dummy worker.
                self.driver_worker = RayWorkerWrapper(**worker_wrapper_kwargs)
                self.driver_dummy_worker = 1 # HACK!! just because there will be a None check for dummy_worker
                if not ray.get_gpu_ids():
                    # instead of checking the dummy worker, we directly check gpu allocation of the current actor
                    raise ValueError(
                        "Ray does not allocate any GPUs on the driver node. Consider "
                        "adjusting the Ray placement group or running the driver on a "
                        "GPU node.")
                continue
            #### END

Then, we get the Ray context info without the dummy worker:

Original:

worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
                                                        use_dummy_driver=True)

Changed:

        ### BEGIN
        if ray.get_runtime_context().get_actor_id():
            # worker_driver should be enough to get ray context
            worker_node_and_gpu_ids =  self._run_workers("get_node_and_gpu_ids")
        else:
            worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
                                                        use_dummy_driver=True)
        ### END

Finally, change the Ray Serve entrypoint. Note that we drop one bundle and add ray_actor_options to specify the resources of the deployment replica actor:

    pg_resources = []
    # The deployment replica will also use a GPU for AsyncLLMEngine.
    ### BEGIN
    for i in range(tp):
        pg_resources.append({"CPU": 1, "GPU": 1})  # for the vLLM actors
    # We use the "STRICT_PACK" strategy below to ensure all vLLM actors are placed on
    # the same Ray node.
    return VLLMDeployment.options(
        # allocate resources for the deployment replica actor
        ray_actor_options={
            "num_gpus": 1,
            "num_cpus": 1,
        },
        placement_group_bundles=pg_resources,
        placement_group_strategy="STRICT_PACK",
    ).bind(
        # ... same bind arguments as in the original build_app above ...
    ### END

@kouroshHakha
Contributor

@huiyeruzhou can you try the new native llm api in ray serve and see if the issue persists?
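
For reference, a rough sketch of that API, assuming the ray.serve.llm module from recent Ray releases; the exact import paths and config fields are my assumptions and may differ in the 2.43 alpha:

# Rough sketch of the native Ray Serve LLM API (field names assumed from the docs).
from ray import serve
from ray.serve.llm import LLMConfig, build_openai_app

llm_config = LLMConfig(
    model_loading_config=dict(
        model_id="qwen2.5-0.5b",
        model_source="Qwen/Qwen2.5-0.5B",
    ),
    deployment_config=dict(
        autoscaling_config=dict(min_replicas=1, max_replicas=1),
    ),
    engine_kwargs=dict(tensor_parallel_size=1),
)

app = build_openai_app({"llm_configs": [llm_config]})
serve.run(app)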

@huiyeruzhou

@huiyeruzhou can you try the new native llm api in ray serve and see if the issue persists?

Hi! Here are my experiment findings. I discovered that the Ray Serve and Ray LLM APIs are equivalent, with the placement group (PG) configuration being the key factor.

The default PG configuration is: [{'CPU': 1} + {'GPU': 1} * TP]. This setup works when TP > 1 but fails when TP = 1.
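
In code form, that default layout is roughly the following (an illustrative sketch, not the actual Ray source):

# Illustrative sketch of the default bundle layout described above.
tp = 2  # tensor_parallel_size
pg_bundles = [{"CPU": 1}] + [{"GPU": 1} for _ in range(tp)]
# tp = 2 -> [{'CPU': 1}, {'GPU': 1}, {'GPU': 1}]  (works)
# tp = 1 -> [{'CPU': 1}, {'GPU': 1}]              (fails; see the table below)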

For a detailed analysis, see #51242.

| vLLM Version | Placement Group Configuration | TP | Status | Notes |
| --- | --- | --- | --- | --- |
| vLLM 0.7.3 | [{'CPU':1} + {'GPU':1} * TP] | >1 | ✅ Works | Replica actor has no GPU but gains access via update_environment_variables |
| vLLM 0.7.3 | [{'GPU':1} * TP] | >1 | ❌ Fails | Extra worker creation causes deadlock due to loop in ray_distributed_executor.py#L187 |
| vLLM 0.7.3 | [{'CPU':1} + {'GPU':1} * TP] | 1 | ❌ Fails | Replica actor has no GPU, and Executor can no longer "borrow" CUDA_VISIBLE_DEVICES |
| vLLM 0.7.3 | [{'GPU':1} * TP] | 1 | ✅ Works | Replica actor has no GPU, but uniproc_executor avoids dummy worker creation |
