Commit fa8caec

(vllm) fix import error
1 parent dd93d89

File tree

2 files changed: +33, -10 lines

vllm/entrypoints/llm.py

Lines changed: 32 additions & 7 deletions
@@ -10,7 +10,7 @@
 import torch.nn as nn
 from tqdm import tqdm
 from typing_extensions import TypeVar, deprecated
-
+import torch
 from vllm import envs
 from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                               BeamSearchSequence, get_beam_search_score)
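
This first hunk is the fix named in the commit title: the new `prompt_embeds: Optional[torch.Tensor]` parameters added below reference `torch` at module import time, so `llm.py` needs a top-level `import torch`. A minimal standalone sketch of the failure mode (illustration only, not vLLM code):

    from typing import Optional

    # Without this import, the annotation below raises
    #   NameError: name 'torch' is not defined
    # at module load time, because parameter annotations are evaluated
    # when the `def` statement executes.
    import torch


    def generate(prompt_embeds: Optional[torch.Tensor] = None) -> None:
        ...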
@@ -286,6 +286,8 @@ def generate(
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         guided_options_request: Optional[Union[LLMGuidedOptions,
                                                GuidedDecodingRequest]] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        priority: Optional[list[int]] = None,
     ) -> list[RequestOutput]:
         ...

@@ -302,6 +304,8 @@ def generate(
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         guided_options_request: Optional[Union[LLMGuidedOptions,
                                                GuidedDecodingRequest]] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        priority: Optional[list[int]] = None,
     ) -> list[RequestOutput]:
         ...

@@ -318,6 +322,8 @@ def generate(
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         guided_options_request: Optional[Union[LLMGuidedOptions,
                                                GuidedDecodingRequest]] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        priority: Optional[list[int]] = None,
     ) -> list[RequestOutput]:
         ...

@@ -335,6 +341,8 @@ def generate(
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         guided_options_request: Optional[Union[LLMGuidedOptions,
                                                GuidedDecodingRequest]] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        priority: Optional[list[int]] = None,
     ) -> list[RequestOutput]:
         ...

@@ -352,6 +360,8 @@ def generate(
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         guided_options_request: Optional[Union[LLMGuidedOptions,
                                                GuidedDecodingRequest]] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        priority: Optional[list[int]] = None,
     ) -> list[RequestOutput]:
         ...

@@ -367,6 +377,8 @@ def generate(
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         guided_options_request: Optional[Union[LLMGuidedOptions,
                                                GuidedDecodingRequest]] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        priority: Optional[list[int]] = None,
     ) -> list[RequestOutput]:
         ...

@@ -381,7 +393,7 @@ def generate(
                                 Optional[Union[str, list[str]]]] = None,
         sampling_params: Optional[Union[SamplingParams,
                                         Sequence[SamplingParams]]] = None,
-        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
+        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
         prompt_embeds: Optional[torch.Tensor] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
@@ -405,10 +417,15 @@ def generate(
                 When it is a single value, it is applied to every prompt.
                 When it is a list, the list must have the same length as the
                 prompts and it is paired one by one with the prompt.
+            prompt_token_ids: DEPRECATED. Token IDs for the prompts. If provided,
+                the `prompts` will be ignored.
+            prompt_embeds: Optional tensor of prompt embeddings to use instead of
+                text prompts.
             use_tqdm: Whether to use tqdm to display the progress bar.
             lora_request: LoRA request to use for generation, if any.
             prompt_adapter_request: Prompt Adapter request to use for
                 generation, if any.
+            guided_options_request: Options for guided decoding, if any.
             priority: The priority of the requests, if any.
                 Only applicable when priority scheduling policy is enabled.

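For orientation, a hedged usage sketch of the widened `generate` signature. The model name, the embedding shape, and the `prompts=None` call pattern are assumptions for illustration; whether the engine consumes `prompt_embeds` end to end is not shown by this commit:

    import torch
    from vllm import LLM, SamplingParams

    llm = LLM(model="Qwen/Qwen2-0.5B")  # placeholder model choice
    params = SamplingParams(max_tokens=32)

    # Per-request priorities; only meaningful when the priority
    # scheduling policy is enabled, per the docstring above.
    outputs = llm.generate(["Hello", "World"], sampling_params=params,
                           priority=[0, 1])

    # Hypothetical embeddings path: a (num_tokens, hidden_size) tensor.
    # hidden_size=896 matches Qwen2-0.5B, but that pairing is an assumption.
    embeds = torch.randn(8, 896)
    outputs = llm.generate(prompts=None, sampling_params=params,
                           prompt_embeds=embeds)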
@@ -442,13 +459,13 @@ def generate(
             parsed_prompts = self._convert_v1_inputs(
                 prompts=cast(Optional[Union[str, list[str]]], prompts),
                 prompt_token_ids=prompt_token_ids,
+                prompt_embeds=prompt_embeds,
             )
         else:
             parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                   prompts)
-
-        if prompt_embeds is not None:
-            parsed_prompts.prompt_embeds = prompt_embeds
+        if prompt_embeds is not None and hasattr(parsed_prompts, "prompt_embeds"):
+            parsed_prompts.prompt_embeds = prompt_embeds

         if isinstance(guided_options_request, dict):
             if len(guided_options_request) > 1:
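
One behavior of the new `hasattr` guard is worth flagging: vLLM prompt inputs are frequently plain strings or TypedDict-style dicts, and `hasattr` checks attributes, not dict keys, so on those inputs the assignment silently becomes a no-op. A standalone illustration of the guard semantics (not vLLM code):

    import torch

    embeds = torch.randn(4, 896)  # arbitrary shape, for illustration only

    # dict-style prompt: hasattr() inspects attributes, never keys,
    # so the guard is False and the embeddings are silently dropped.
    parsed = {"prompt": "Hello"}
    assert not hasattr(parsed, "prompt_embeds")

    # attribute-style prompt: the guard passes only when the attribute
    # already exists on the object (or its class).
    class ParsedPrompt:
        prompt_embeds = None  # pre-declared, so hasattr() is True

    obj = ParsedPrompt()
    if hasattr(obj, "prompt_embeds"):
        obj.prompt_embeds = embeds  # reached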
@@ -1229,8 +1246,8 @@ def wake_up(self):
     # LEGACY
     def _convert_v1_inputs(
         self,
-        prompts: Optional[Union[str, List[str]]],
-        prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
+        prompts: Optional[Union[str, list[str]]],
+        prompt_token_ids: Optional[Union[list[int], list[list[int]]]],
         prompt_embeds: Optional[torch.Tensor] = None,
     ):
         # skip_tokenizer_init is now checked in engine
@@ -1269,6 +1286,13 @@ def _convert_v1_inputs(

             parsed_prompts.append(item)

+        # Handle prompt_embeds if provided
+        if prompt_embeds is not None:
+            # Assuming prompt_embeds is a tensor that can be assigned to the first prompt
+            # This might need adjustment based on how prompt_embeds is actually used
+            if len(parsed_prompts) > 0 and hasattr(parsed_prompts[0], "prompt_embeds"):
+                parsed_prompts[0].prompt_embeds = prompt_embeds
+
         return parsed_prompts

     def _validate_and_add_requests(
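
As the in-diff comments themselves concede, this attaches the whole tensor to `parsed_prompts[0]` only; in a batch, every later prompt is left without embeddings. If per-prompt embeddings were the intent, the distribution would look roughly like the sketch below (a hypothetical alternative, assuming `prompt_embeds` arrives as a sequence with one tensor per prompt; this is not what the commit does):

    from typing import Any, Sequence

    import torch


    def attach_prompt_embeds(parsed_prompts: Sequence[Any],
                             prompt_embeds: Sequence[torch.Tensor]) -> None:
        """Hypothetical per-prompt distribution -- NOT what the commit does.

        Assumes one embedding tensor per prompt, in order.
        """
        if len(prompt_embeds) != len(parsed_prompts):
            raise ValueError("expected one embedding tensor per prompt")
        for prompt, embeds in zip(parsed_prompts, prompt_embeds):
            if hasattr(prompt, "prompt_embeds"):
                prompt.prompt_embeds = embeds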
@@ -1403,3 +1427,4 @@ def _run_engine(
         # This is necessary because some requests may be finished earlier than
         # its previous requests.
         return sorted(outputs, key=lambda x: int(x.request_id))
+

vllm/model_executor/models/qwen2.py

Lines changed: 1 addition & 3 deletions
@@ -461,9 +461,7 @@ def forward(
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:

-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors,
-                                   inputs_embeds, self.lm_head.bias)
+        hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds, self.lm_head.bias)
         return hidden_states

     def compute_logits(
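
The rewritten call drops `kv_caches` and `attn_metadata` but still threads `self.lm_head.bias` through as a trailing positional argument, which is fragile against further signature churn. A keyword form would fail loudly on a mismatch instead of silently binding the wrong parameter; a sketch of the same method body, where every parameter name (including `lm_head_bias`) is an assumption rather than the verified `Qwen2Model.forward` signature:

    # Hypothetical keyword form of the same call; the parameter names
    # (including lm_head_bias) are assumptions, not verified against
    # the actual Qwen2Model.forward signature.
    hidden_states = self.model(
        input_ids=input_ids,
        positions=positions,
        intermediate_tensors=intermediate_tensors,
        inputs_embeds=inputs_embeds,
        lm_head_bias=self.lm_head.bias,
    )
    return hidden_states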
