4444from vllm .outputs import (ClassificationRequestOutput , EmbeddingRequestOutput ,
4545 PoolingRequestOutput , RequestOutput ,
4646 ScoringRequestOutput )
47- from vllm .pooling_params import PoolingParams
47+ from vllm .pooling_params import PoolingParams , PoolingTask
4848from vllm .prompt_adapter .request import PromptAdapterRequest
4949from vllm .sampling_params import (BeamSearchParams , GuidedDecodingParams ,
5050 RequestOutputKind , SamplingParams )
@@ -964,6 +964,7 @@ def encode(
964964 use_tqdm : Union [bool , Callable [..., tqdm ]] = True ,
965965 lora_request : Optional [Union [list [LoRARequest ], LoRARequest ]] = None ,
966966 prompt_adapter_request : Optional [PromptAdapterRequest ] = None ,
967+ pooling_task : PoolingTask = "encode" ,
967968 ) -> list [PoolingRequestOutput ]:
968969 ...
969970
@@ -979,6 +980,7 @@ def encode(
979980 use_tqdm : Union [bool , Callable [..., tqdm ]] = True ,
980981 lora_request : Optional [Union [list [LoRARequest ], LoRARequest ]] = None ,
981982 prompt_adapter_request : Optional [PromptAdapterRequest ] = None ,
983+ pooling_task : PoolingTask = "encode" ,
982984 ) -> list [PoolingRequestOutput ]:
983985 ...
984986
@@ -994,6 +996,7 @@ def encode(
994996 use_tqdm : Union [bool , Callable [..., tqdm ]] = True ,
995997 lora_request : Optional [Union [list [LoRARequest ], LoRARequest ]] = None ,
996998 prompt_adapter_request : Optional [PromptAdapterRequest ] = None ,
999+ pooling_task : PoolingTask = "encode" ,
9971000 ) -> list [PoolingRequestOutput ]:
9981001 ...
9991002
@@ -1010,6 +1013,7 @@ def encode(
10101013 use_tqdm : Union [bool , Callable [..., tqdm ]] = True ,
10111014 lora_request : Optional [Union [list [LoRARequest ], LoRARequest ]] = None ,
10121015 prompt_adapter_request : Optional [PromptAdapterRequest ] = None ,
1016+ pooling_task : PoolingTask = "encode" ,
10131017 ) -> list [PoolingRequestOutput ]:
10141018 ...
10151019
@@ -1026,6 +1030,7 @@ def encode(
10261030 use_tqdm : Union [bool , Callable [..., tqdm ]] = True ,
10271031 lora_request : Optional [Union [list [LoRARequest ], LoRARequest ]] = None ,
10281032 prompt_adapter_request : Optional [PromptAdapterRequest ] = None ,
1033+ pooling_task : PoolingTask = "encode" ,
10291034 ) -> list [PoolingRequestOutput ]:
10301035 ...
10311036
@@ -1040,6 +1045,7 @@ def encode(
10401045 use_tqdm : Union [bool , Callable [..., tqdm ]] = True ,
10411046 lora_request : Optional [Union [list [LoRARequest ], LoRARequest ]] = None ,
10421047 prompt_adapter_request : Optional [PromptAdapterRequest ] = None ,
1048+ pooling_task : PoolingTask = "encode" ,
10431049 ) -> list [PoolingRequestOutput ]:
10441050 ...
10451051
@@ -1059,6 +1065,7 @@ def encode(
10591065 use_tqdm : Union [bool , Callable [..., tqdm ]] = True ,
10601066 lora_request : Optional [Union [list [LoRARequest ], LoRARequest ]] = None ,
10611067 prompt_adapter_request : Optional [PromptAdapterRequest ] = None ,
1068+ pooling_task : PoolingTask = "encode" ,
10621069 ) -> list [PoolingRequestOutput ]:
10631070 """Apply pooling to the hidden states corresponding to the input
10641071 prompts.
@@ -1080,6 +1087,7 @@ def encode(
10801087 lora_request: LoRA request to use for generation, if any.
10811088 prompt_adapter_request: Prompt Adapter request to use for
10821089 generation, if any.
1090+ pooling_task: Override the pooling task to use.
10831091
10841092 Returns:
10851093 A list of `PoolingRequestOutput` objects containing the
@@ -1116,11 +1124,12 @@ def encode(
11161124 if pooling_params is None :
11171125 # Use default pooling params.
11181126 pooling_params = PoolingParams ()
1119- elif isinstance (pooling_params , PoolingParams ):
1120- pooling_params .verify (model_config )
1127+
1128+ if isinstance (pooling_params , PoolingParams ):
1129+ pooling_params .verify (pooling_task , model_config )
11211130 else :
11221131 for pooling_param in pooling_params :
1123- pooling_param .verify (model_config )
1132+ pooling_param .verify (pooling_task , model_config )
11241133
11251134 tokenization_kwargs = dict [str , Any ]()
11261135 _validate_truncation_size (model_config .max_model_len ,
@@ -1181,12 +1190,15 @@ def embed(
11811190 raise ValueError ("Embedding API is not supported by this model. "
11821191 "Please set `--task embed`." )
11831192
1184- items = self .encode (prompts ,
1185- truncate_prompt_tokens = truncate_prompt_tokens ,
1186- use_tqdm = use_tqdm ,
1187- pooling_params = pooling_params ,
1188- lora_request = lora_request ,
1189- prompt_adapter_request = prompt_adapter_request )
1193+ items = self .encode (
1194+ prompts ,
1195+ truncate_prompt_tokens = truncate_prompt_tokens ,
1196+ use_tqdm = use_tqdm ,
1197+ pooling_params = pooling_params ,
1198+ lora_request = lora_request ,
1199+ prompt_adapter_request = prompt_adapter_request ,
1200+ pooling_task = "embed" ,
1201+ )
11901202
11911203 return [EmbeddingRequestOutput .from_base (item ) for item in items ]
11921204
@@ -1228,10 +1240,13 @@ def classify(
12281240 "Classification API is not supported by this model. "
12291241 "Please set `--task classify`." )
12301242
1231- items = self .encode (prompts ,
1232- use_tqdm = use_tqdm ,
1233- lora_request = lora_request ,
1234- prompt_adapter_request = prompt_adapter_request )
1243+ items = self .encode (
1244+ prompts ,
1245+ use_tqdm = use_tqdm ,
1246+ lora_request = lora_request ,
1247+ prompt_adapter_request = prompt_adapter_request ,
1248+ pooling_task = "classify" ,
1249+ )
12351250
12361251 return [ClassificationRequestOutput .from_base (item ) for item in items ]
12371252
@@ -1251,7 +1266,9 @@ def _embedding_score(
12511266 truncate_prompt_tokens = truncate_prompt_tokens ,
12521267 use_tqdm = use_tqdm ,
12531268 lora_request = lora_request ,
1254- prompt_adapter_request = prompt_adapter_request )
1269+ prompt_adapter_request = prompt_adapter_request ,
1270+ pooling_task = "embed" ,
1271+ )
12551272
12561273 encoded_output_1 : list [PoolingRequestOutput ] = encoded_output [
12571274 0 :len (text_1 )]
@@ -1287,7 +1304,7 @@ def _cross_encoding_score(
12871304 if len (data_1 ) == 1 :
12881305 data_1 = data_1 * len (data_2 )
12891306
1290- pooling_params = PoolingParams (use_cross_encoder = True )
1307+ pooling_params = PoolingParams (task = "score" )
12911308 tokenization_kwargs : dict [str , Any ] = {}
12921309 _validate_truncation_size (self .llm_engine .model_config .max_model_len ,
12931310 truncate_prompt_tokens , tokenization_kwargs )
0 commit comments