2424import weakref
2525from contextlib import contextmanager , nullcontext
2626from dataclasses import dataclass
27- from typing import TYPE_CHECKING , Dict , List , Optional , Union
27+ from typing import TYPE_CHECKING , Dict , List , Optional , Union , cast , get_args
2828
2929import numpy as np
3030import numpy .typing as npt
4545from vllm .model_executor .layers .fused_moe import FusedMoE
4646from vllm .model_executor .layers .rotary_embedding import MRotaryEmbedding
4747from vllm .model_executor .model_loader import get_model
48- from vllm .model_executor .models .interfaces import has_step_pooler
48+ from vllm .model_executor .models .interfaces_base import (VllmModelForPooling ,
49+ is_pooling_model )
4950from vllm .multimodal import MULTIMODAL_REGISTRY
5051from vllm .multimodal .inputs import MultiModalKwargs , PlaceholderRange
5152from vllm .multimodal .utils import group_mm_inputs_by_modality
8990
9091if vllm_version_is ("0.9.2" ):
9192 from vllm .v1 .utils import bind_kv_cache
93+ from vllm .model_executor .models .interfaces import has_step_pooler
94+ PoolingTask = None
9295else :
9396 from vllm .v1 .worker .utils import bind_kv_cache
97+ from vllm .pooling_params import PoolingTask
9498
9599if TYPE_CHECKING :
96100 import xgrammar as xgr # type: ignore[import-untyped]
@@ -402,6 +406,19 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
402406 else :
403407 generator = None
404408
409+ if not vllm_version_is ("0.9.2" ) and pooling_params :
410+ pooling_params = new_req_data .pooling_params
411+ assert pooling_params .task is not None , (
412+ "You did not set `task` in the API" )
413+
414+ model = cast (VllmModelForPooling , self .model )
415+ to_update = (model .pooler .get_pooling_updates (
416+ pooling_params .task ))
417+ assert to_update is not None , (
418+ f"{ pooling_params .task = } is not supported by the model" )
419+
420+ to_update .apply (pooling_params )
421+
405422 self .requests [req_id ] = CachedRequestState (
406423 req_id = req_id ,
407424 prompt_token_ids = new_req_data .prompt_token_ids ,
@@ -1729,26 +1746,57 @@ def _dummy_pooler_run(
17291746
17301747 req_num_tokens = num_tokens // num_reqs
17311748
1732- dummy_metadata = PoolingMetadata (
1733- prompt_lens = torch .tensor ([h .shape [0 ] for h in hidden_states_list ],
1734- device = self .device ),
1735- prompt_token_ids = torch .zeros ((num_reqs , req_num_tokens ),
1736- dtype = torch .int32 ,
1737- device = self .device ),
1738- pooling_params = [PoolingParams ()] * num_reqs )
1739-
1740- try :
1741- pooler_output = self .model .pooler (hidden_states = hidden_states_list ,
1742- pooling_metadata = dummy_metadata )
1743- except RuntimeError as e :
1744- if 'out of memory' in str (e ):
1745- raise RuntimeError (
1746- "NPU out of memory occurred when warming up pooler with "
1747- f"{ num_reqs } dummy requests. Please try lowering "
1748- "`max_num_seqs` or `gpu_memory_utilization` when "
1749- "initializing the engine." ) from e
1750- else :
1751- raise e
1749+ if vllm_version_is ("0.9.2" ):
1750+ dummy_metadata = PoolingMetadata (
1751+ prompt_lens = torch .tensor ([h .shape [0 ] for h in hidden_states_list ],
1752+ device = self .device ),
1753+ prompt_token_ids = torch .zeros ((num_reqs , req_num_tokens ),
1754+ dtype = torch .int32 ,
1755+ device = self .device ),
1756+ pooling_params = [PoolingParams ()] * num_reqs )
1757+
1758+ try :
1759+ pooler_output = self .model .pooler (hidden_states = hidden_states_list ,
1760+ pooling_metadata = dummy_metadata )
1761+ except RuntimeError as e :
1762+ if 'out of memory' in str (e ):
1763+ raise RuntimeError (
1764+ "NPU out of memory occurred when warming up pooler with "
1765+ f"{ num_reqs } dummy requests. Please try lowering "
1766+ "`max_num_seqs` or `gpu_memory_utilization` when "
1767+ "initializing the engine." ) from e
1768+ else :
1769+ raise e
1770+ else :
1771+ model = cast (VllmModelForPooling , self .model )
1772+ dummy_task = self .get_supported_pooling_tasks ()[0 ]
1773+ dummy_pooling_params = PoolingParams (task = dummy_task )
1774+
1775+ to_update = model .pooler .get_pooling_updates (dummy_task )
1776+ assert to_update is not None
1777+ to_update .apply (dummy_pooling_params )
1778+
1779+ dummy_metadata = PoolingMetadata (
1780+ prompt_lens = torch .tensor ([h .shape [0 ] for h in hidden_states_list ],
1781+ device = self .device ),
1782+ prompt_token_ids = torch .zeros ((num_reqs , req_num_tokens ),
1783+ dtype = torch .int32 ,
1784+ device = self .device ),
1785+ pooling_params = [dummy_pooling_params ] * num_reqs )
1786+
1787+ try :
1788+ pooler_output = model .pooler (hidden_states = hidden_states_list ,
1789+ pooling_metadata = dummy_metadata )
1790+ except RuntimeError as e :
1791+ if 'out of memory' in str (e ):
1792+ raise RuntimeError (
1793+ "NPU out of memory occurred when warming up pooler with "
1794+ f"{ num_reqs } dummy requests. Please try lowering "
1795+ "`max_num_seqs` or `gpu_memory_utilization` when "
1796+ "initializing the engine." ) from e
1797+ else :
1798+ raise e
1799+
17521800 return pooler_output
17531801
17541802 def load_model (self ) -> None :
@@ -1767,7 +1815,8 @@ def load_model(self) -> None:
17671815 QKVParallelLinear , RowParallelLinear )):
17681816 module .weight .data = torch_npu .npu_format_cast (
17691817 module .weight .data , ACL_FORMAT_FRACTAL_NZ )
1770- if has_step_pooler (self .model ):
1818+
1819+ if vllm_version_is ("0.9.2" ) and has_step_pooler (self .model ):
17711820 self .input_batch .logits_processing_needs_token_ids = True
17721821 if self .drafter :
17731822 logger .info ("Loading drafter model..." )
@@ -2379,3 +2428,13 @@ def select_torchair_padded_batch_size(self, batch_size: int):
23792428 if batch_size <= padded_batch_size < selected_batch_size :
23802429 selected_batch_size = padded_batch_size
23812430 return selected_batch_size
2431+
def get_supported_pooling_tasks(self) -> list[PoolingTask]:
    """List the pooling tasks accepted by the loaded model's pooler.

    Candidate tasks are the type arguments of ``PoolingTask`` as reported
    by ``typing.get_args``. A candidate is kept when
    ``model.pooler.get_pooling_updates(candidate)`` returns a truthy value.

    Returns:
        The supported tasks, in ``get_args(PoolingTask)`` order. The list
        is empty when ``is_pooling_model`` rejects the model.
    """
    supported = []
    model = self.get_model()

    if is_pooling_model(model):
        for candidate in get_args(PoolingTask):
            # Any truthy result from the pooler counts as support.
            if model.pooler.get_pooling_updates(candidate):
                supported.append(candidate)

    return supported
0 commit comments