Error on running Qwen/Qwen2-VL-72B-Instruct-AWQ #230

Closed
markowanga opened this issue Sep 19, 2024 · 3 comments


markowanga commented Sep 19, 2024

I get the following error when I run vLLM:

docker run \
  --gpus all \
  --ipc=host \
  --network=host \
  --rm \
  -v "/home/user/.cache/huggingface:/root/.cache/huggingface" \
  --name qwen2 \
  -it -p 8000:8000 \
  qwenllm/qwenvl:2-cu121 \
  vllm serve Qwen/Qwen2-VL-72B-Instruct-AWQ --host 0.0.0.0 --api-key=sample_pass --enforce-eager --tensor-parallel-size 2
Traceback (most recent call last):
  File "/usr/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/usr/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python3.8/dist-packages/vllm/entrypoints/openai/rpc/server.py", line 236, in run_rpc_server
    server = AsyncEngineRPCServer(async_engine_args, usage_context, rpc_path)
  File "/usr/local/lib/python3.8/dist-packages/vllm/entrypoints/openai/rpc/server.py", line 34, in __init__
    self.engine = AsyncLLMEngine.from_engine_args(
  File "/usr/local/lib/python3.8/dist-packages/vllm/engine/async_llm_engine.py", line 735, in from_engine_args
    engine = cls(
  File "/usr/local/lib/python3.8/dist-packages/vllm/engine/async_llm_engine.py", line 615, in __init__
    self.engine = self._init_engine(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/vllm/engine/async_llm_engine.py", line 835, in _init_engine
    return engine_class(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/vllm/engine/async_llm_engine.py", line 262, in __init__
    super().__init__(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/vllm/engine/llm_engine.py", line 324, in __init__
    self.model_executor = executor_class(
  File "/usr/local/lib/python3.8/dist-packages/vllm/executor/multiproc_gpu_executor.py", line 222, in __init__
    super().__init__(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/vllm/executor/distributed_gpu_executor.py", line 26, in __init__
    super().__init__(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/vllm/executor/executor_base.py", line 47, in __init__
    self._init_executor()
  File "/usr/local/lib/python3.8/dist-packages/vllm/executor/multiproc_gpu_executor.py", line 125, in _init_executor
    self._run_workers("load_model",
  File "/usr/local/lib/python3.8/dist-packages/vllm/executor/multiproc_gpu_executor.py", line 199, in _run_workers
    driver_worker_output = driver_worker_method(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/vllm/worker/worker.py", line 182, in load_model
    self.model_runner.load_model()
  File "/usr/local/lib/python3.8/dist-packages/vllm/worker/model_runner.py", line 995, in load_model
    self.model = get_model(model_config=self.model_config,
  File "/usr/local/lib/python3.8/dist-packages/vllm/model_executor/model_loader/__init__.py", line 19, in get_model
    return loader.load_model(model_config=model_config,
  File "/usr/local/lib/python3.8/dist-packages/vllm/model_executor/model_loader/loader.py", line 357, in load_model
    model = _initialize_model(model_config, self.load_config,
  File "/usr/local/lib/python3.8/dist-packages/vllm/model_executor/model_loader/loader.py", line 171, in _initialize_model
    return build_model(
  File "/usr/local/lib/python3.8/dist-packages/vllm/model_executor/model_loader/loader.py", line 156, in build_model
    return model_class(config=hf_config,
  File "/usr/local/lib/python3.8/dist-packages/vllm/model_executor/models/qwen2_vl.py", line 726, in __init__
    self.model = Qwen2Model(config, cache_config, quant_config)
  File "/usr/local/lib/python3.8/dist-packages/vllm/model_executor/models/qwen2.py", line 243, in __init__
    self.start_layer, self.end_layer, self.layers = make_layers(
  File "/usr/local/lib/python3.8/dist-packages/vllm/model_executor/models/utils.py", line 248, in make_layers
    [PPMissingLayer() for _ in range(start_layer)] + [
  File "/usr/local/lib/python3.8/dist-packages/vllm/model_executor/models/utils.py", line 249, in <listcomp>
    maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
  File "/usr/local/lib/python3.8/dist-packages/vllm/model_executor/models/qwen2.py", line 245, in <lambda>
    lambda prefix: Qwen2DecoderLayer(config=config,
  File "/usr/local/lib/python3.8/dist-packages/vllm/model_executor/models/qwen2.py", line 184, in __init__
    self.mlp = Qwen2MLP(
  File "/usr/local/lib/python3.8/dist-packages/vllm/model_executor/models/qwen2.py", line 69, in __init__
    self.down_proj = RowParallelLinear(intermediate_size,
  File "/usr/local/lib/python3.8/dist-packages/vllm/model_executor/layers/linear.py", line 974, in __init__
    self.quant_method.create_weights(
  File "/usr/local/lib/python3.8/dist-packages/vllm/model_executor/layers/quantization/awq_marlin.py", line 162, in create_weights
    verify_marlin_supports_shape(
  File "/usr/local/lib/python3.8/dist-packages/vllm/model_executor/layers/quantization/utils/marlin_utils.py", line 106, in verify_marlin_supports_shape
    raise ValueError(f"Weight input_size_per_partition = "
ValueError: Weight input_size_per_partition = 14784 is not divisible by min_thread_k = 128. Consider reducing tensor_parallel_size or running with --quantization gptq.
ERROR 09-19 18:35:12 api_server.py:188] RPCServer process died before responding to readiness probe

I have 2x RTX 3090, and I can run other LLMs on this two-card configuration.
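
For context, the 14784 in the ValueError appears to come from sharding down_proj: the model's MLP intermediate_size (29568, inferred from the traceback rather than stated in it) divided across the two tensor-parallel ranks is not a multiple of the Marlin kernel's min_thread_k of 128. A minimal sketch of that arithmetic:

# Sketch of the shape check that fails above (values taken from the traceback).
intermediate_size = 29568       # inferred: 14784 * tensor_parallel_size
tensor_parallel_size = 2
min_thread_k = 128              # Marlin constraint reported in the error

input_size_per_partition = intermediate_size // tensor_parallel_size
print(input_size_per_partition)                 # 14784
print(input_size_per_partition % min_thread_k)  # 64 -> not divisible, hence the ValueError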

kq-chen self-assigned this on Sep 20, 2024
@QwertyJack

Related to #231

kq-chen (Collaborator) commented Sep 24, 2024

Based on the suggestion from aabbccddwasd in #231, we have adjusted the intermediate size to 29696 and re-quantized the model. The updated 72B AWQ/GPTQ-Int4/GPTQ-Int8 checkpoints have been uploaded to Hugging Face; to use them, please download the checkpoints again.
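
As a quick sanity check (not part of the original comment), the padded intermediate size divides cleanly under the common tensor-parallel degrees, so the Marlin shape check above passes:

# Why 29696 works: every per-rank shard is a multiple of 128.
new_intermediate_size = 29696
for tp in (2, 4, 8):
    per_partition = new_intermediate_size // tp
    print(tp, per_partition, per_partition % 128 == 0)
# 2 -> 14848 True, 4 -> 7424 True, 8 -> 3712 True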

You can use the following commands to perform inference on the quantized 72B model with vLLM tensor parallelism:

Server:

VLLM_WORKER_MULTIPROC_METHOD=spawn python -m vllm.entrypoints.openai.api_server \
  --served-model-name qwen2vl \
  --model Qwen/Qwen2-VL-72B-Instruct-AWQ \
  --tensor-parallel-size 4 \
  --max_num_seqs 16

Client:

curl http://localhost:8000/v1/chat/completions \
    -H "Content-Type: application/json" \
    -d '{
    "model": "qwen2vl",
    "messages": [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": [
        {"type": "image_url", "image_url": {"url": "https://modelscope.oss-cn-beijing.aliyuncs.com/resource/qwen.png"}},
        {"type": "text", "text": "What is the text in the illustration?"}
    ]}
    ]
    }'
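
For reference, here is an equivalent request using the OpenAI Python SDK. This is a sketch; the base_url and placeholder api_key assume a local server started without --api-key:

# Equivalent of the curl call above using the OpenAI Python SDK (pip install openai).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="qwen2vl",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": [
            {"type": "image_url",
             "image_url": {"url": "https://modelscope.oss-cn-beijing.aliyuncs.com/resource/qwen.png"}},
            {"type": "text", "text": "What is the text in the illustration?"},
        ]},
    ],
)
print(response.choices[0].message.content)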

kq-chen closed this as completed on Sep 24, 2024
@Pcyslist

Please try pip install ray.

The reason is that vLLM supports distributed tensor-parallel inference and serving. Currently, it supports Megatron-LM's tensor parallel algorithm and manages the distributed runtime with Ray. To run distributed inference, install Ray with:
pip install ray

In your case, --tensor-parallel-size 2 makes vLLM use the Ray library.
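
A minimal offline-inference sketch showing where tensor_parallel_size plugs in (illustrative only; the prompt is plain text, and chat-template/multimodal formatting is omitted):

from vllm import LLM, SamplingParams

# With tensor_parallel_size > 1, vLLM shards the weights across GPUs and runs
# the workers through a distributed backend (Ray, or multiprocessing, depending
# on version and configuration).
llm = LLM(
    model="Qwen/Qwen2-VL-72B-Instruct-AWQ",
    tensor_parallel_size=2,
    enforce_eager=True,
)
outputs = llm.generate(["Describe this model in one sentence."],
                       SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)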
