2424import weakref
2525from contextlib import contextmanager , nullcontext
2626from dataclasses import dataclass
27- from typing import TYPE_CHECKING , Dict , List , Optional , Union
27+ from typing import TYPE_CHECKING , Dict , List , Optional , Union , cast , get_args
2828
2929import numpy as np
3030import numpy .typing as npt
4545from vllm .model_executor .layers .fused_moe import FusedMoE
4646from vllm .model_executor .layers .rotary_embedding import MRotaryEmbedding
4747from vllm .model_executor .model_loader import get_model
48- from vllm .model_executor .models .interfaces import has_step_pooler
48+ from vllm .model_executor .models .interfaces_base import (VllmModelForPooling ,
49+ is_pooling_model )
4950from vllm .multimodal import MULTIMODAL_REGISTRY
5051from vllm .multimodal .inputs import MultiModalKwargs , PlaceholderRange
5152from vllm .multimodal .utils import group_mm_inputs_by_modality
8889from vllm_ascend .worker .npu_input_batch import CachedRequestState , InputBatch
8990
9091if vllm_version_is ("0.9.2" ):
92+ from vllm .model_executor .models .interfaces import has_step_pooler
9193 from vllm .v1 .utils import bind_kv_cache
94+ PoolingTask = None
9295else :
96+ from vllm .pooling_params import PoolingTask
9397 from vllm .v1 .worker .utils import bind_kv_cache
9498
9599if TYPE_CHECKING :
@@ -402,6 +406,19 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
402406 else :
403407 generator = None
404408
409+ if not vllm_version_is ("0.9.2" ) and pooling_params :
410+ pooling_params = new_req_data .pooling_params
411+ assert pooling_params .task is not None , (
412+ "You did not set `task` in the API" )
413+
414+ model = cast (VllmModelForPooling , self .model )
415+ to_update = (model .pooler .get_pooling_updates (
416+ pooling_params .task ))
417+ assert to_update is not None , (
418+ f"{ pooling_params .task = } is not supported by the model" )
419+
420+ to_update .apply (pooling_params )
421+
405422 self .requests [req_id ] = CachedRequestState (
406423 req_id = req_id ,
407424 prompt_token_ids = new_req_data .prompt_token_ids ,
@@ -1729,26 +1746,59 @@ def _dummy_pooler_run(
17291746
17301747 req_num_tokens = num_tokens // num_reqs
17311748
1732- dummy_metadata = PoolingMetadata (
1733- prompt_lens = torch .tensor ([h .shape [0 ] for h in hidden_states_list ],
1734- device = self .device ),
1735- prompt_token_ids = torch .zeros ((num_reqs , req_num_tokens ),
1736- dtype = torch .int32 ,
1737- device = self .device ),
1738- pooling_params = [PoolingParams ()] * num_reqs )
1739-
1740- try :
1741- pooler_output = self .model .pooler (hidden_states = hidden_states_list ,
1742- pooling_metadata = dummy_metadata )
1743- except RuntimeError as e :
1744- if 'out of memory' in str (e ):
1745- raise RuntimeError (
1746- "NPU out of memory occurred when warming up pooler with "
1747- f"{ num_reqs } dummy requests. Please try lowering "
1748- "`max_num_seqs` or `gpu_memory_utilization` when "
1749- "initializing the engine." ) from e
1750- else :
1751- raise e
1749+ if vllm_version_is ("0.9.2" ):
1750+ dummy_metadata = PoolingMetadata (
1751+ prompt_lens = torch .tensor (
1752+ [h .shape [0 ] for h in hidden_states_list ],
1753+ device = self .device ),
1754+ prompt_token_ids = torch .zeros ((num_reqs , req_num_tokens ),
1755+ dtype = torch .int32 ,
1756+ device = self .device ),
1757+ pooling_params = [PoolingParams ()] * num_reqs )
1758+ try :
1759+ pooler_output = self .model .pooler (
1760+ hidden_states = hidden_states_list ,
1761+ pooling_metadata = dummy_metadata )
1762+ except RuntimeError as e :
1763+ if 'out of memory' in str (e ):
1764+ raise RuntimeError (
1765+ "NPU out of memory occurred when warming up pooler with "
1766+ f"{ num_reqs } dummy requests. Please try lowering "
1767+ "`max_num_seqs` or `gpu_memory_utilization` when "
1768+ "initializing the engine." ) from e
1769+ else :
1770+ raise e
1771+ else :
1772+ model = cast (VllmModelForPooling , self .model )
1773+ dummy_task = self .get_supported_pooling_tasks ()[0 ]
1774+ dummy_pooling_params = PoolingParams (task = dummy_task )
1775+
1776+ to_update = model .pooler .get_pooling_updates (dummy_task )
1777+ assert to_update is not None
1778+ to_update .apply (dummy_pooling_params )
1779+
1780+ dummy_metadata = PoolingMetadata (
1781+ prompt_lens = torch .tensor (
1782+ [h .shape [0 ] for h in hidden_states_list ],
1783+ device = self .device ),
1784+ prompt_token_ids = torch .zeros ((num_reqs , req_num_tokens ),
1785+ dtype = torch .int32 ,
1786+ device = self .device ),
1787+ pooling_params = [dummy_pooling_params ] * num_reqs )
1788+
1789+ try :
1790+ pooler_output = model .pooler (hidden_states = hidden_states_list ,
1791+ pooling_metadata = dummy_metadata )
1792+ except RuntimeError as e :
1793+ if 'out of memory' in str (e ):
1794+ raise RuntimeError (
1795+ "NPU out of memory occurred when warming up pooler with "
1796+ f"{ num_reqs } dummy requests. Please try lowering "
1797+ "`max_num_seqs` or `gpu_memory_utilization` when "
1798+ "initializing the engine." ) from e
1799+ else :
1800+ raise e
1801+
17521802 return pooler_output
17531803
17541804 def load_model (self ) -> None :
@@ -1767,7 +1817,8 @@ def load_model(self) -> None:
17671817 QKVParallelLinear , RowParallelLinear )):
17681818 module .weight .data = torch_npu .npu_format_cast (
17691819 module .weight .data , ACL_FORMAT_FRACTAL_NZ )
1770- if has_step_pooler (self .model ):
1820+
1821+ if vllm_version_is ("0.9.2" ) and has_step_pooler (self .model ):
17711822 self .input_batch .logits_processing_needs_token_ids = True
17721823 if self .drafter :
17731824 logger .info ("Loading drafter model..." )
@@ -2379,3 +2430,13 @@ def select_torchair_padded_batch_size(self, batch_size: int):
23792430 if batch_size <= padded_batch_size < selected_batch_size :
23802431 selected_batch_size = padded_batch_size
23812432 return selected_batch_size
2433+
2434+ def get_supported_pooling_tasks (self ) -> list [PoolingTask ]:
2435+ model = self .get_model ()
2436+ if not is_pooling_model (model ):
2437+ return []
2438+
2439+ return [
2440+ task for task in get_args (PoolingTask )
2441+ if model .pooler .get_pooling_updates (task )
2442+ ]
0 commit comments