Skip to content

Commit 628e9ba

Browse files
yyzxw and noooop
authored and committed
[FIX] Throwing an exception when the model does not support pool tasks (vllm-project#25840) (vllm-project#25855)
Signed-off-by: zxw <1020938856@qq.com>
Co-authored-by: wang.yuqi <noooop@126.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent 6aab772 commit 628e9ba

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

vllm/model_executor/models/adapters.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -399,6 +399,9 @@ def as_reward_model(cls: _T) -> _T:
399399
# Lazy import
400400
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
401401

402+
from .interfaces_base import default_pooling_type
403+
404+
@default_pooling_type("ALL")
402405
class ModelForReward(_create_pooling_model_cls(cls)):
403406
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
404407
pooler_config = vllm_config.model_config.pooler_config

vllm/v1/worker/gpu_model_runner.py

Lines changed: 21 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3622,8 +3622,28 @@ def _dummy_pooler_run(
36223622
hidden_states: torch.Tensor,
36233623
) -> PoolerOutput:
36243624
# Find the task that has the largest output for subsequent steps
3625+
supported_pooling_tasks = self.get_supported_pooling_tasks()
3626+
3627+
if not supported_pooling_tasks:
3628+
if self.scheduler_config.chunked_prefill_enabled:
3629+
raise RuntimeError(
3630+
f"Model {self.model_config.model} does not support "
3631+
"any pooling tasks with chunked prefill enabled. "
3632+
"Please add --no-enable-chunked-prefill to your "
3633+
"config or CLI args. See "
3634+
"https://docs.vllm.ai/en/latest/models/pooling_models.html "
3635+
"to learn more."
3636+
)
3637+
else:
3638+
raise RuntimeError(
3639+
f"Model {self.model_config.model} does not support "
3640+
"any pooling tasks. See "
3641+
"https://docs.vllm.ai/en/latest/models/pooling_models.html "
3642+
"to learn more."
3643+
)
3644+
36253645
output_size = dict[PoolingTask, float]()
3626-
for task in self.get_supported_pooling_tasks():
3646+
for task in supported_pooling_tasks:
36273647
# Run a full batch with each task to ensure none of them OOMs
36283648
output = self._dummy_pooler_run_task(hidden_states, task)
36293649
output_size[task] = sum(o.nbytes for o in output)

0 commit comments

Comments
 (0)