[Bug]: No Cuda GPUs are available when running vLLM on Ray (Qwen 2.5 VL AWQ) #14456

Open · 1 task done
Fmak95 opened this issue Mar 7, 2025 · 3 comments
Labels
bug (Something isn't working), ray (anything related with ray)


Fmak95 commented Mar 7, 2025

Your current environment

The output of `python collect_env.py`
Collecting environment information...
PyTorch version: 2.5.1+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Chainguard (x86_64)
GCC version: (Wolfi 14.2.0-r4) 14.2.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.40

Python version: 3.10.15 (tags/v3.10.15-0-gffee63f-dirty:ffee63f, Sep 23 2024, 21:00:09) [GCC 14.2.0] (64-bit runtime)
Python platform: Linux-5.10.227-219.884.amzn2.x86_64-x86_64-with-glibc2.40
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA L40S
GPU 1: NVIDIA L40S
GPU 2: NVIDIA L40S
GPU 3: NVIDIA L40S

Nvidia driver version: 550.127.05
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
/bin/sh: lscpu: not found

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-ml-py==12.570.86
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] optree==0.14.0
[pip3] pyzmq==26.2.1
[pip3] torch==2.5.1
[pip3] torchaudio==2.5.1
[pip3] torchvision==0.20.1
[pip3] transformers==4.49.0.dev0
[pip3] triton==3.1.0
[conda] Could not collect
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.7.2
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0	GPU1	GPU2	GPU3	CPU Affinity	NUMA Affinity	GPU NUMA ID
GPU0	 X 	SYS	SYS	SYS	0-47	0		N/A
GPU1	SYS	 X 	SYS	SYS	0-47	0		N/A
GPU2	SYS	SYS	 X 	SYS	0-47	0		N/A
GPU3	SYS	SYS	SYS	 X 	0-47	0		N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NVIDIA_VISIBLE_DEVICES=GPU-b082cd89-9bbb-de73-9315-68108bc20cee,GPU-8a6a364e-e153-e010-ca44-a7f769347275,GPU-70a9674c-c409-363d-d360-cbcbaf04741a,GPU-2f017e94-5df7-bcc0-fce2-defee9018d3c
NVIDIA_DRIVER_CAPABILITIES=compute,utility
NCCL_P2P_LEVEL=NVL
VLLM_ENGINE_ITERATION_TIMEOUT_S=1500
VLLM_ATTENTION_BACKEND=FLASH_ATTN
CUDA_VISIBLE_DEVICES=0,1,2,3
VLLM_FLASH_ATTN_VERSION=2
CUDA_HOME=/usr
VLLM_LOGGING_LEVEL=DEBUG
LD_LIBRARY_PATH=/home/ray/venv/lib/python3.10/site-packages/cv2/../../lib64:
NCCL_CUMEM_ENABLE=0
TORCHINDUCTOR_COMPILE_THREADS=1
CUDA_MODULE_LOADING=LAZY

🐛 Describe the bug

Hi all,

I am hosting vLLM with Ray Serve on a Kubernetes cluster (AWS EKS) and run into a strange issue when trying to serve the AWQ-quantized Qwen2.5-VL-7B-Instruct-AWQ model. The stack trace implies that no CUDA GPU is available, but I can see my GPUs when I SSH into the pod and run nvidia-smi.

This error seems specific to the AWQ-quantized model versions: when I run the same config with the non-quantized Qwen models, everything works.
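
As a sanity check, a minimal sketch along these lines (hypothetical and untested here; it assumes `ray.init(address="auto")` can attach to the running cluster) compares GPU visibility on the driver with what a GPU-assigned Ray task sees:

import os

import ray
import torch


@ray.remote(num_gpus=1)
def gpu_visibility() -> dict:
    # Report what a GPU-assigned Ray worker actually sees.
    return {
        "CUDA_VISIBLE_DEVICES": os.environ.get("CUDA_VISIBLE_DEVICES"),
        "cuda_available": torch.cuda.is_available(),
        "device_count": torch.cuda.device_count(),
    }


if __name__ == "__main__":
    ray.init(address="auto")  # attach to the existing Ray cluster inside the pod
    print("driver:", torch.cuda.is_available(), torch.cuda.device_count())
    print("worker:", ray.get(gpu_visibility.remote()))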

I will share the stack trace below:

Stacktrace

```
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 451, in result
    return self.__get_result()
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
  File "/tmp/ray/session_2025-03-06_15-58-41_147160_1/runtime_resources/pip/ca98275b0a29b201c71e0a5ee4919249e83e7317/virtualenv/lib/python3.10/site-packages/ray/serve/_private/replica.py", line 952, in initialize_and_get_metadata
    await self._replica_impl.initialize(deployment_config)
  File "/tmp/ray/session_2025-03-06_15-58-41_147160_1/runtime_resources/pip/ca98275b0a29b201c71e0a5ee4919249e83e7317/virtualenv/lib/python3.10/site-packages/ray/serve/_private/replica.py", line 687, in initialize
    raise RuntimeError(traceback.format_exc()) from None
RuntimeError: Traceback (most recent call last):
  File "/tmp/ray/session_2025-03-06_15-58-41_147160_1/runtime_resources/pip/ca98275b0a29b201c71e0a5ee4919249e83e7317/virtualenv/lib/python3.10/site-packages/ray/serve/_private/replica.py", line 664, in initialize
    self._user_callable_asgi_app = await asyncio.wrap_future(
  File "/tmp/ray/session_2025-03-06_15-58-41_147160_1/runtime_resources/pip/ca98275b0a29b201c71e0a5ee4919249e83e7317/virtualenv/lib/python3.10/site-packages/ray/serve/_private/replica.py", line 1350, in initialize_callable
    await self._call_func_or_gen(
  File "/tmp/ray/session_2025-03-06_15-58-41_147160_1/runtime_resources/pip/ca98275b0a29b201c71e0a5ee4919249e83e7317/virtualenv/lib/python3.10/site-packages/ray/serve/_private/replica.py", line 1311, in _call_func_or_gen
    result = callable(*args, **kwargs)
  File "/tmp/ray/session_2025-03-06_15-58-41_147160_1/runtime_resources/pip/ca98275b0a29b201c71e0a5ee4919249e83e7317/virtualenv/lib/python3.10/site-packages/ray/serve/api.py", line 221, in __init__
    cls.__init__(self, *args, **kwargs)
  File "/home/ray/src/vllm_florence2_prototype/api_server.py", line 51, in __init__
    self.engine = AsyncLLMEngine.from_engine_args(engine_args)
  File "/tmp/ray/session_2025-03-06_15-58-41_147160_1/runtime_resources/pip/ca98275b0a29b201c71e0a5ee4919249e83e7317/virtualenv/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 639, in from_engine_args
    engine_config = engine_args.create_engine_config(usage_context)
  File "/tmp/ray/session_2025-03-06_15-58-41_147160_1/runtime_resources/pip/ca98275b0a29b201c71e0a5ee4919249e83e7317/virtualenv/lib/python3.10/site-packages/vllm/engine/arg_utils.py", line 1276, in create_engine_config
    config = VllmConfig(
  File "", line 19, in __init__
  File "/tmp/ray/session_2025-03-06_15-58-41_147160_1/runtime_resources/pip/ca98275b0a29b201c71e0a5ee4919249e83e7317/virtualenv/lib/python3.10/site-packages/vllm/config.py", line 3225, in __post_init__
    self.quant_config = VllmConfig._get_quantization_config(
  File "/tmp/ray/session_2025-03-06_15-58-41_147160_1/runtime_resources/pip/ca98275b0a29b201c71e0a5ee4919249e83e7317/virtualenv/lib/python3.10/site-packages/vllm/config.py", line 3181, in _get_quantization_config
    raise ValueError(
ValueError: torch.bfloat16 is not supported for quantization method awq. Supported dtypes: [torch.float16]

qwen2-vl-72b-instruct:
  status: RUNNING
  message: ''
  last_deployed_time_s: 1741284567.6790574
  deployments:
    VLLMDeployment:
      status: HEALTHY
      status_trigger: CONFIG_UPDATE_COMPLETED
      replica_states:
        RUNNING: 1
      message: ''
qwen2-vl-72b-instruct-pt2:
  status: DEPLOY_FAILED
  message: Failed to update the deployments ['VLLMDeployment'].
  last_deployed_time_s: 1741284567.6790574
  deployments:
    VLLMDeployment:
      status: DEPLOY_FAILED
      status_trigger: REPLICA_STARTUP_FAILED
      replica_states: {}
      message: |-
        The deployment failed to start 3 times in a row. This may be due to a problem with its constructor or initial health check failing. See controller logs for details.
        Error: ray::ServeReplica:qwen2-vl-72b-instruct-pt2:VLLMDeployment.initialize_and_get_metadata() (pid=1535, ip=10.191.48.218, actor_id=fd4051c85f9ba380d843991501000000, repr=)

  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 451, in result
    return self.__get_result()
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
  File "/tmp/ray/session_2025-03-06_15-58-41_147160_1/runtime_resources/pip/c20f8b0f241e058835c7df2d61cb7d5ba10c9661/virtualenv/lib/python3.10/site-packages/ray/serve/_private/replica.py", line 952, in initialize_and_get_metadata
    await self._replica_impl.initialize(deployment_config)
  File "/tmp/ray/session_2025-03-06_15-58-41_147160_1/runtime_resources/pip/c20f8b0f241e058835c7df2d61cb7d5ba10c9661/virtualenv/lib/python3.10/site-packages/ray/serve/_private/replica.py", line 687, in initialize
    raise RuntimeError(traceback.format_exc()) from None
RuntimeError: Traceback (most recent call last):
  File "/tmp/ray/session_2025-03-06_15-58-41_147160_1/runtime_resources/pip/c20f8b0f241e058835c7df2d61cb7d5ba10c9661/virtualenv/lib/python3.10/site-packages/ray/serve/_private/replica.py", line 664, in initialize
    self._user_callable_asgi_app = await asyncio.wrap_future(
  File "/tmp/ray/session_2025-03-06_15-58-41_147160_1/runtime_resources/pip/c20f8b0f241e058835c7df2d61cb7d5ba10c9661/virtualenv/lib/python3.10/site-packages/ray/serve/_private/replica.py", line 1350, in initialize_callable
    await self._call_func_or_gen(
  File "/tmp/ray/session_2025-03-06_15-58-41_147160_1/runtime_resources/pip/c20f8b0f241e058835c7df2d61cb7d5ba10c9661/virtualenv/lib/python3.10/site-packages/ray/serve/_private/replica.py", line 1311, in _call_func_or_gen
    result = callable(*args, **kwargs)
  File "/tmp/ray/session_2025-03-06_15-58-41_147160_1/runtime_resources/pip/c20f8b0f241e058835c7df2d61cb7d5ba10c9661/virtualenv/lib/python3.10/site-packages/ray/serve/api.py", line 221, in __init__
    cls.__init__(self, *args, **kwargs)
  File "/home/ray/src/vllm_florence2_prototype/api_server.py", line 51, in __init__
    self.engine = AsyncLLMEngine.from_engine_args(engine_args)
  File "/tmp/ray/session_2025-03-06_15-58-41_147160_1/runtime_resources/pip/c20f8b0f241e058835c7df2d61cb7d5ba10c9661/virtualenv/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 644, in from_engine_args
    engine = cls(
  File "/tmp/ray/session_2025-03-06_15-58-41_147160_1/runtime_resources/pip/c20f8b0f241e058835c7df2d61cb7d5ba10c9661/virtualenv/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 594, in __init__
    self.engine = self._engine_class(*args, **kwargs)
  File "/tmp/ray/session_2025-03-06_15-58-41_147160_1/runtime_resources/pip/c20f8b0f241e058835c7df2d61cb7d5ba10c9661/virtualenv/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 267, in __init__
    super().__init__(*args, **kwargs)
  File "/tmp/ray/session_2025-03-06_15-58-41_147160_1/runtime_resources/pip/c20f8b0f241e058835c7df2d61cb7d5ba10c9661/virtualenv/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 273, in __init__
    self.model_executor = executor_class(vllm_config=vllm_config, )
  File "/tmp/ray/session_2025-03-06_15-58-41_147160_1/runtime_resources/pip/c20f8b0f241e058835c7df2d61cb7d5ba10c9661/virtualenv/lib/python3.10/site-packages/vllm/executor/executor_base.py", line 262, in __init__
    super().__init__(*args, **kwargs)
  File "/tmp/ray/session_2025-03-06_15-58-41_147160_1/runtime_resources/pip/c20f8b0f241e058835c7df2d61cb7d5ba10c9661/virtualenv/lib/python3.10/site-packages/vllm/executor/executor_base.py", line 51, in __init__
    self._init_executor()
  File "/tmp/ray/session_2025-03-06_15-58-41_147160_1/runtime_resources/pip/c20f8b0f241e058835c7df2d61cb7d5ba10c9661/virtualenv/lib/python3.10/site-packages/vllm/executor/ray_distributed_executor.py", line 90, in _init_executor
    self._init_workers_ray(placement_group)
  File "/tmp/ray/session_2025-03-06_15-58-41_147160_1/runtime_resources/pip/c20f8b0f241e058835c7df2d61cb7d5ba10c9661/virtualenv/lib/python3.10/site-packages/vllm/executor/ray_distributed_executor.py", line 355, in _init_workers_ray
    self._run_workers("init_device")
  File "/tmp/ray/session_2025-03-06_15-58-41_147160_1/runtime_resources/pip/c20f8b0f241e058835c7df2d61cb7d5ba10c9661/virtualenv/lib/python3.10/site-packages/vllm/executor/ray_distributed_executor.py", line 476, in _run_workers
    self.driver_worker.execute_method(sent_method, *args, **kwargs)
  File "/tmp/ray/session_2025-03-06_15-58-41_147160_1/runtime_resources/pip/c20f8b0f241e058835c7df2d61cb7d5ba10c9661/virtualenv/lib/python3.10/site-packages/vllm/worker/worker_base.py", line 575, in execute_method
    raise e
  File "/tmp/ray/session_2025-03-06_15-58-41_147160_1/runtime_resources/pip/c20f8b0f241e058835c7df2d61cb7d5ba10c9661/virtualenv/lib/python3.10/site-packages/vllm/worker/worker_base.py", line 566, in execute_method
    return run_method(target, method, args, kwargs)
  File "/tmp/ray/session_2025-03-06_15-58-41_147160_1/runtime_resources/pip/c20f8b0f241e058835c7df2d61cb7d5ba10c9661/virtualenv/lib/python3.10/site-packages/vllm/utils.py", line 2220, in run_method
    return func(*args, **kwargs)
  File "/tmp/ray/session_2025-03-06_15-58-41_147160_1/runtime_resources/pip/c20f8b0f241e058835c7df2d61cb7d5ba10c9661/virtualenv/lib/python3.10/site-packages/vllm/worker/worker.py", line 155, in init_device
    torch.cuda.set_device(self.device)
  File "/tmp/ray/session_2025-03-06_15-58-41_147160_1/runtime_resources/pip/c20f8b0f241e058835c7df2d61cb7d5ba10c9661/virtualenv/lib/python3.10/site-packages/torch/cuda/__init__.py", line 478, in set_device
    torch._C._cuda_setDevice(device)
  File "/tmp/ray/session_2025-03-06_15-58-41_147160_1/runtime_resources/pip/c20f8b0f241e058835c7df2d61cb7d5ba10c9661/virtualenv/lib/python3.10/site-packages/torch/cuda/__init__.py", line 319, in _lazy_init
    torch._C._cuda_init()
RuntimeError: No CUDA GPUs are available
```

nvidia-smi output
[screenshot attached]

Here are my Ray Serve configs and entrypoint script:

    - name: qwen2-vl-7b-instruct
      route_prefix: /qwen2-vl-7b-instruct
      import_path: vllm_florence2_prototype.api_server:build_app
      runtime_env:
        pip:
          - "git+https://github.com/huggingface/transformers.git@11afab19c0e4b652855f9ed7f82aa010c4f14754"
          - "vllm[video]==0.7.2"
          - "qwen-vl-utils[decord]"
          - "ninja"
      deployments:
      - name: VLLMDeployment
        max_ongoing_requests: 1000
        autoscaling_config:
          min_replicas: 1
          max_replicas: 5
      args:
        model: "Qwen/Qwen2.5-VL-7B-Instruct-AWQ"
        tensor_parallel_size: 4
        max_model_len: 32768
        trust_remote_code: true
        dtype: auto
        device: "auto"

Entrypoint script:

import os

from typing import Dict, Optional, List
import logging

from fastapi import FastAPI
import pkg_resources
from starlette.requests import Request
from starlette.responses import StreamingResponse, JSONResponse

from ray import serve

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.metrics import RayPrometheusStatLogger
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ErrorResponse,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_models import (
    BaseModelPath,
    LoRAModulePath,
    OpenAIServingModels,
)
from vllm.utils import FlexibleArgumentParser

logger = logging.getLogger("ray.serve")

app = FastAPI()


@serve.deployment(name="VLLMDeployment")
@serve.ingress(app)
class VLLMDeployment:
    def __init__(
        self,
        engine_args: AsyncEngineArgs,
        response_role: str,
        lora_modules: Optional[List[LoRAModulePath]] = None,
        chat_template: Optional[str] = None,
    ):
        logger.info(f"Starting with engine args: {engine_args}")
        self.openai_serving_chat = None
        self.engine_args = engine_args
        self.response_role = response_role
        self.lora_modules = lora_modules
        self.chat_template = chat_template
        self.engine = AsyncLLMEngine.from_engine_args(engine_args)

    @app.post("/v1/chat/completions")
    async def create_chat_completion(
        self, request: ChatCompletionRequest, raw_request: Request
    ):
        """OpenAI-compatible HTTP endpoint.

        API reference:
            - https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
        """
        if not self.openai_serving_chat:
            model_config = await self.engine.get_model_config()
            # Determine the name of the served model for the OpenAI client.
            if self.engine_args.served_model_name is not None:
                served_model_names = self.engine_args.served_model_name
            else:
                served_model_names = [self.engine_args.model]

            base_model_paths = [
                BaseModelPath(name=name, model_path=self.engine_args.model)
                for name in served_model_names
            ]

            openai_serving_models = OpenAIServingModels(
                engine_client=self.engine,
                model_config=model_config,
                base_model_paths=base_model_paths,
            )

            self.openai_serving_chat = OpenAIServingChat(
                self.engine,
                model_config,
                openai_serving_models,
                response_role=self.response_role,
                chat_template=self.chat_template,
                chat_template_content_format="auto",
                request_logger=None,
            )
        logger.info(f"Request: {request}")
        generator = await self.openai_serving_chat.create_chat_completion(
            request, raw_request
        )
        if isinstance(generator, ErrorResponse):
            return JSONResponse(
                content=generator.model_dump(), status_code=generator.code
            )
        if request.stream:
            return StreamingResponse(content=generator, media_type="text/event-stream")
        else:
            assert isinstance(generator, ChatCompletionResponse)
            return JSONResponse(content=generator.model_dump())


def parse_vllm_args(cli_args: Dict[str, str]):
    """Parses vLLM args based on CLI inputs.

    Currently uses argparse because vLLM doesn't expose Python models for all of the
    config options we want to support.
    """
    parser = FlexibleArgumentParser(description="vLLM CLI")
    parser = make_arg_parser(parser)
    arg_strings = []
    for key, value in cli_args.items():
        arg_strings.extend([f"--{key}", str(value)])
    logger.info(arg_strings)
    parsed_args = parser.parse_args(args=arg_strings)
    return parsed_args


def build_app(cli_args: Dict[str, str]) -> serve.Application:
    """Builds the Serve app based on CLI arguments.

    See https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#command-line-arguments-for-the-server
    for the complete set of arguments.

    Supported engine arguments: https://docs.vllm.ai/en/latest/models/engine_args.html.
    """  # noqa: E501

    enforce_eager = cli_args.pop("enforce_eager", False)
    trust_remote_code = cli_args.pop("trust_remote_code", False)
    disable_custom_all_reduce = cli_args.pop("disable_custom_all_reduce", False)
    disable_frontend_multiprocessing = cli_args.pop(
        "disable_frontend_multiprocessing", False
    )
    disable_async_output_proc = cli_args.pop("disable_async_output_proc", False)

    logger.info(f"CLI ARGS: {cli_args}")
    parsed_args = parse_vllm_args(cli_args)
    logger.info(f"PARSED ARGS: {parsed_args}")
    engine_args = AsyncEngineArgs.from_cli_args(parsed_args)

    logger.info(f"ENFORCE EAGER: {enforce_eager}")
    logger.info(f"TRUST REMOTE CODE: {trust_remote_code}")

    if enforce_eager:
        engine_args.enforce_eager = True
    if trust_remote_code:
        engine_args.trust_remote_code = True
    if disable_custom_all_reduce:
        engine_args.disable_custom_all_reduce = True
    if disable_frontend_multiprocessing:
        engine_args.disable_frontend_multiprocessing = True
    if disable_async_output_proc:
        engine_args.disable_async_output_proc = True

    logger.info(f"ENGINE ARGS: {engine_args}")
    # engine_args.worker_use_ray = True
    engine_args.worker_use_ray = True

    accelerator = "GPU"
    pg_resources = []
    tp = engine_args.tensor_parallel_size
    for i in range(tp):
        pg_resources.append({"CPU": 1, accelerator: 1})
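    # One {"CPU": 1, "GPU": 1} bundle per tensor-parallel rank; STRICT_PACK below keeps
    # all bundles on the same node. (Per Ray Serve's placement_group_bundles semantics,
    # the replica actor itself is scheduled into the first bundle.)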

    return VLLMDeployment.options(
        placement_group_bundles=pg_resources, placement_group_strategy="STRICT_PACK"
    ).bind(
        engine_args,
        parsed_args.response_role,
        parsed_args.lora_modules,
        parsed_args.chat_template,
    )
    # return VLLMDeployment.bind(
    #     engine_args,
    #     parsed_args.response_role,
    #     parsed_args.lora_modules,
    #     parsed_args.chat_template,
    # )

Has anyone else seen a similar issue? I would really appreciate some help and guidance.

Fmak95 added the bug (Something isn't working) label on Mar 7, 2025
vincent-4 (Contributor) commented Mar 7, 2025

ValueError: torch.bfloat16 is not supported for quantization method awq. Supported dtypes: [torch.float16]

Guess: try setting dtype to float16 instead of auto?
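
Applied to the engine arguments in the config above, that suggestion would look roughly like this (a sketch only, reusing the values from the Serve config; not verified on this deployment):

from vllm.engine.arg_utils import AsyncEngineArgs

# Same values as the Serve config's `args`, but with dtype pinned to float16,
# since the ValueError says AWQ only supports torch.float16 ("auto" resolves
# to bfloat16 for this checkpoint).
engine_args = AsyncEngineArgs(
    model="Qwen/Qwen2.5-VL-7B-Instruct-AWQ",
    tensor_parallel_size=4,
    max_model_len=32768,
    trust_remote_code=True,
    dtype="float16",
)

In the Serve config this corresponds to changing `dtype: auto` to `dtype: float16` under `args`.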

johnny12150 commented:

It should work when you downgrade vllm to 0.6.6.post1

ray-project/ray#51154


pang-wu commented Mar 10, 2025

@vincent-4 we tried that; the same issue persists. I believe what @johnny12150 mentioned is the cause, but we can't roll back to 0.6.6.post1 because it doesn't have Qwen2.5-VL support.
