import weakref
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union, cast, get_args

import numpy as np
import numpy.typing as npt
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import get_model
-from vllm.model_executor.models.interfaces import has_step_pooler
+from vllm.model_executor.models.interfaces_base import (VllmModelForPooling,
+                                                        is_pooling_model)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch

if vllm_version_is("0.9.2"):
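+    # has_step_pooler is only used on the vLLM 0.9.2 path; newer versions
+    # select pooling behavior per request via PoolingTask (imported below).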
+    from vllm.model_executor.models.interfaces import has_step_pooler
    from vllm.v1.utils import bind_kv_cache
else:
+    from vllm.pooling_params import PoolingTask
    from vllm.v1.worker.utils import bind_kv_cache

if TYPE_CHECKING:
@@ -395,13 +398,24 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
        for new_req_data in scheduler_output.scheduled_new_reqs:
            req_id = new_req_data.req_id
            sampling_params = new_req_data.sampling_params
+            pooling_params = new_req_data.pooling_params
            if sampling_params and \
                sampling_params.sampling_type == SamplingType.RANDOM_SEED:
                generator = torch.Generator(device=self.device)
                generator.manual_seed(sampling_params.seed)
            else:
                generator = None

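+            # Pooling requests on newer vLLM carry an explicit task; verify
+            # that the model's pooler supports it and let it update the
+            # params before the request state is cached.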
+            if not vllm_version_is("0.9.2") and pooling_params:
+                assert pooling_params.task is not None, (
+                    "You did not set `task` in the API")
+                model = cast(VllmModelForPooling, self.model)
+                to_update = model.pooler.get_pooling_updates(
+                    pooling_params.task)
+                assert to_update is not None, (
+                    f"{pooling_params.task=} is not supported by the model")
+                to_update.apply(pooling_params)
+
            self.requests[req_id] = CachedRequestState(
                req_id=req_id,
                prompt_token_ids=new_req_data.prompt_token_ids,
@@ -1729,26 +1743,59 @@ def _dummy_pooler_run(

        req_num_tokens = num_tokens // num_reqs

-        dummy_metadata = PoolingMetadata(
-            prompt_lens=torch.tensor([h.shape[0] for h in hidden_states_list],
-                                     device=self.device),
-            prompt_token_ids=torch.zeros((num_reqs, req_num_tokens),
-                                         dtype=torch.int32,
-                                         device=self.device),
-            pooling_params=[PoolingParams()] * num_reqs)
-
-        try:
-            pooler_output = self.model.pooler(hidden_states=hidden_states_list,
-                                              pooling_metadata=dummy_metadata)
-        except RuntimeError as e:
-            if 'out of memory' in str(e):
-                raise RuntimeError(
-                    "NPU out of memory occurred when warming up pooler with "
-                    f"{num_reqs} dummy requests. Please try lowering "
-                    "`max_num_seqs` or `gpu_memory_utilization` when "
-                    "initializing the engine.") from e
-            else:
-                raise e
+        if vllm_version_is("0.9.2"):
+            dummy_metadata = PoolingMetadata(
+                prompt_lens=torch.tensor(
+                    [h.shape[0] for h in hidden_states_list],
+                    device=self.device),
+                prompt_token_ids=torch.zeros((num_reqs, req_num_tokens),
+                                             dtype=torch.int32,
+                                             device=self.device),
+                pooling_params=[PoolingParams()] * num_reqs)
+            try:
+                pooler_output = self.model.pooler(
+                    hidden_states=hidden_states_list,
+                    pooling_metadata=dummy_metadata)
+            except RuntimeError as e:
+                if 'out of memory' in str(e):
+                    raise RuntimeError(
+                        "NPU out of memory occurred when warming up pooler with "
+                        f"{num_reqs} dummy requests. Please try lowering "
+                        "`max_num_seqs` or `gpu_memory_utilization` when "
+                        "initializing the engine.") from e
+                else:
+                    raise e
+        else:
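+            # Newer vLLM: warm up with the first pooling task this model
+            # supports, so the dummy run follows the same path as real
+            # pooling requests.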
+            model = cast(VllmModelForPooling, self.model)
+            dummy_task = self.get_supported_pooling_tasks()[0]
+            dummy_pooling_params = PoolingParams(task=dummy_task)
+
+            to_update = model.pooler.get_pooling_updates(dummy_task)
+            assert to_update is not None
+            to_update.apply(dummy_pooling_params)
+
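+            # Same dummy metadata as the 0.9.2 branch, but built with the
+            # task-aware pooling params prepared above.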
+            dummy_metadata = PoolingMetadata(
+                prompt_lens=torch.tensor(
+                    [h.shape[0] for h in hidden_states_list],
+                    device=self.device),
+                prompt_token_ids=torch.zeros((num_reqs, req_num_tokens),
+                                             dtype=torch.int32,
+                                             device=self.device),
+                pooling_params=[dummy_pooling_params] * num_reqs)
+
+            try:
+                pooler_output = model.pooler(hidden_states=hidden_states_list,
+                                             pooling_metadata=dummy_metadata)
+            except RuntimeError as e:
+                if 'out of memory' in str(e):
+                    raise RuntimeError(
+                        "NPU out of memory occurred when warming up pooler with "
+                        f"{num_reqs} dummy requests. Please try lowering "
+                        "`max_num_seqs` or `gpu_memory_utilization` when "
+                        "initializing the engine.") from e
+                else:
+                    raise e
+
        return pooler_output

    def load_model(self) -> None:
@@ -1767,8 +1814,9 @@ def load_model(self) -> None:
                               QKVParallelLinear, RowParallelLinear)):
                    module.weight.data = torch_npu.npu_format_cast(
                        module.weight.data, ACL_FORMAT_FRACTAL_NZ)
-        if has_step_pooler(self.model):
-            self.input_batch.logits_processing_needs_token_ids = True
+
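+        # Step pooling is a 0.9.2-only interface; on newer vLLM the pooling
+        # params are handled per request in _update_states above.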
+        if vllm_version_is("0.9.2") and has_step_pooler(self.model):
+            self.input_batch.logits_processing_needs_token_ids_bool = True
        if self.drafter:
            logger.info("Loading drafter model...")
            if isinstance(self.drafter, EagleProposer):
@@ -2379,3 +2427,13 @@ def select_torchair_padded_batch_size(self, batch_size: int):
            if batch_size <= padded_batch_size < selected_batch_size:
                selected_batch_size = padded_batch_size
        return selected_batch_size
+
+    def get_supported_pooling_tasks(self):
+        model = self.get_model()
+        if not is_pooling_model(model):
+            return []
+
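+        # PoolingTask is a Literal type; get_args() enumerates its members,
+        # and we keep only the tasks this model's pooler can handle.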
+        return [
+            task for task in get_args(PoolingTask)
+            if model.pooler.get_pooling_updates(task)
+        ]