 import torch.nn as nn
 from tqdm import tqdm
 from typing_extensions import TypeVar, deprecated
-
+import torch
 from vllm import envs
 from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                               BeamSearchSequence, get_beam_search_score)
@@ -286,6 +286,8 @@ def generate(
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         guided_options_request: Optional[Union[LLMGuidedOptions,
                                                GuidedDecodingRequest]] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        priority: Optional[list[int]] = None,
     ) -> list[RequestOutput]:
         ...

@@ -302,6 +304,8 @@ def generate(
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         guided_options_request: Optional[Union[LLMGuidedOptions,
                                                GuidedDecodingRequest]] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        priority: Optional[list[int]] = None,
     ) -> list[RequestOutput]:
         ...

@@ -318,6 +322,8 @@ def generate(
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         guided_options_request: Optional[Union[LLMGuidedOptions,
                                                GuidedDecodingRequest]] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        priority: Optional[list[int]] = None,
     ) -> list[RequestOutput]:
         ...

@@ -335,6 +341,8 @@ def generate(
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         guided_options_request: Optional[Union[LLMGuidedOptions,
                                                GuidedDecodingRequest]] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        priority: Optional[list[int]] = None,
     ) -> list[RequestOutput]:
         ...

@@ -352,6 +360,8 @@ def generate(
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         guided_options_request: Optional[Union[LLMGuidedOptions,
                                                GuidedDecodingRequest]] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        priority: Optional[list[int]] = None,
     ) -> list[RequestOutput]:
         ...

@@ -367,6 +377,8 @@ def generate(
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         guided_options_request: Optional[Union[LLMGuidedOptions,
                                                GuidedDecodingRequest]] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        priority: Optional[list[int]] = None,
     ) -> list[RequestOutput]:
         ...

@@ -381,7 +393,7 @@ def generate(
                        Optional[Union[str, list[str]]]] = None,
         sampling_params: Optional[Union[SamplingParams,
                                         Sequence[SamplingParams]]] = None,
-        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
+        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
         prompt_embeds: Optional[torch.Tensor] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
@@ -405,10 +417,15 @@ def generate(
                 When it is a single value, it is applied to every prompt.
                 When it is a list, the list must have the same length as the
                 prompts and it is paired one by one with the prompt.
+            prompt_token_ids: DEPRECATED. Token IDs for the prompts. If
+                provided, `prompts` will be ignored.
+            prompt_embeds: Optional tensor of prompt embeddings to use
+                instead of text prompts.
             use_tqdm: Whether to use tqdm to display the progress bar.
             lora_request: LoRA request to use for generation, if any.
             prompt_adapter_request: Prompt Adapter request to use for
                 generation, if any.
+            guided_options_request: Options for guided decoding, if any.
             priority: The priority of the requests, if any.
                 Only applicable when the priority scheduling policy is enabled.

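For reference, a minimal usage sketch of the two new keyword arguments. The model name, the embedding shape, and the placeholder prompt are illustrative assumptions, not part of this diff; `priority` also assumes the engine was configured with a priority scheduling policy.

```python
import torch
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # any supported model works
params = SamplingParams(temperature=0.8, max_tokens=32)

# Per-request priorities, paired one to one with the prompts; they are
# only honored when priority scheduling is enabled on the engine.
outputs = llm.generate(
    ["Hello, my name is", "The capital of France is"],
    sampling_params=params,
    priority=[0, 1],
)
for out in outputs:
    print(out.outputs[0].text)

# Precomputed prompt embeddings; per this diff they are attached to the
# first parsed prompt, so the (seq_len, hidden_size) shape and the
# pairing with a placeholder prompt are assumptions, not a documented
# contract.
embeds = torch.randn(8, 768)
outputs = llm.generate("placeholder", sampling_params=params,
                       prompt_embeds=embeds)
```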
@@ -442,13 +459,13 @@ def generate(
             parsed_prompts = self._convert_v1_inputs(
                 prompts=cast(Optional[Union[str, list[str]]], prompts),
                 prompt_token_ids=prompt_token_ids,
+                prompt_embeds=prompt_embeds,
             )
         else:
             parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                   prompts)
-
-            if prompt_embeds is not None:
-                parsed_prompts.prompt_embeds = prompt_embeds
+            if prompt_embeds is not None and hasattr(parsed_prompts, "prompt_embeds"):
+                parsed_prompts.prompt_embeds = prompt_embeds

         if isinstance(guided_options_request, dict):
             if len(guided_options_request) > 1:
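The `hasattr` guard added above matters because `parsed_prompts` may be a plain string or list, which has no `prompt_embeds` attribute; the old unconditional assignment would raise `AttributeError` there. A minimal illustration, where `EmbedsCarrier` is a hypothetical stand-in for a prompt object that accepts embeddings:

```python
import torch

class EmbedsCarrier:
    """Hypothetical stand-in for a prompt type with an embeddings slot."""
    prompt_embeds = None

def attach(parsed_prompts, prompt_embeds):
    # Mirrors the guarded assignment added in this diff.
    if prompt_embeds is not None and hasattr(parsed_prompts, "prompt_embeds"):
        parsed_prompts.prompt_embeds = prompt_embeds

embeds = torch.randn(4, 16)
attach(["a plain text prompt"], embeds)  # skipped: lists lack the attribute
carrier = EmbedsCarrier()
attach(carrier, embeds)                  # attribute exists, tensor is attached
assert carrier.prompt_embeds is embeds
```

Note the trade-off: the guard turns a crash into a silent no-op, so callers passing `prompt_embeds` alongside an incompatible prompt type get no feedback.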
@@ -1229,8 +1246,8 @@ def wake_up(self):
     # LEGACY
     def _convert_v1_inputs(
         self,
-        prompts: Optional[Union[str, List[str]]],
-        prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
+        prompts: Optional[Union[str, list[str]]],
+        prompt_token_ids: Optional[Union[list[int], list[list[int]]]],
         prompt_embeds: Optional[torch.Tensor] = None,
     ):
         # skip_tokenizer_init is now checked in engine
@@ -1269,6 +1286,13 @@ def _convert_v1_inputs(

             parsed_prompts.append(item)

+        # Handle prompt_embeds if provided
+        if prompt_embeds is not None:
+            # Assuming prompt_embeds is a tensor that can be assigned to the
+            # first prompt; this might need adjustment based on how
+            # prompt_embeds is actually used.
+            if len(parsed_prompts) > 0 and hasattr(parsed_prompts[0], "prompt_embeds"):
+                parsed_prompts[0].prompt_embeds = prompt_embeds
+
         return parsed_prompts

     def _validate_and_add_requests(
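A behavioral note on this legacy path: as written, the embeddings land only on the first converted prompt, and only when that prompt object exposes a `prompt_embeds` attribute; every other case is silently skipped. A toy sketch of that contract, with a hypothetical `Prompt` class standing in for vLLM's parsed prompt types:

```python
import torch

class Prompt:
    """Hypothetical stand-in for a parsed prompt object."""
    def __init__(self, text):
        self.text = text
        self.prompt_embeds = None

def convert_v1_like(texts, prompt_embeds=None):
    # Mirrors the tail of _convert_v1_inputs in this diff: embeddings
    # are attached to the first prompt only.
    parsed = [Prompt(t) for t in texts]
    if prompt_embeds is not None:
        if len(parsed) > 0 and hasattr(parsed[0], "prompt_embeds"):
            parsed[0].prompt_embeds = prompt_embeds
    return parsed

embeds = torch.randn(8, 16)
prompts = convert_v1_like(["first", "second"], prompt_embeds=embeds)
assert prompts[0].prompt_embeds is embeds
assert prompts[1].prompt_embeds is None  # later prompts never receive embeds
```

If per-prompt embeddings are the eventual goal, a list of tensors paired one to one with the prompts would be the more natural interface, as the in-code comment itself anticipates.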
@@ -1403,3 +1427,4 @@ def _run_engine(
         # This is necessary because some requests may be finished earlier than
         # their previous requests.
         return sorted(outputs, key=lambda x: int(x.request_id))
+