
Commit dd93d89

improve embedding input
1 parent 8bcec12 commit dd93d89

8 files changed: +27, -48 lines changed

vllm/attention/backends/flash_attn.py

Lines changed: 0 additions & 1 deletion
@@ -732,7 +732,6 @@ def forward(
         prefill_output = output[:num_prefill_query_tokens]
         assert query.shape[0] == num_prefill_query_tokens
         assert decode_query.shape[0] == num_decode_query_tokens
-
         if prefill_meta := attn_metadata.prefill_metadata:
             # Prompt run.
             if (kv_cache.numel() == 0 or prefill_meta.block_tables is None

vllm/engine/llm_engine.py

Lines changed: 2 additions & 3 deletions
@@ -753,10 +753,9 @@ def add_request(
         if arrival_time is None:
             arrival_time = time.time()
 
-        if isinstance(prompt, dict) and prompt.get("prompt_embeds",
-                                                    None) is not None:
+        if isinstance(prompt, dict) and prompt.get("prompt_embeds", None) is not None:
             if not prompt.get("prompt_token_ids", None):
-                prompt["prompt_token_ids"] = [0] * len(prompt["prompt_embeds"])
+                prompt["prompt_token_ids"] = [0] * prompt["prompt_embeds"].shape[0]
 
         if self.tokenizer is not None:
             self._validate_token_prompt(

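Note: the add_request change above backfills placeholder token IDs when a caller supplies only embeddings, sizing the placeholder list by the embedding tensor's first dimension (one dummy ID per embedded position). A minimal standalone sketch of that logic; the prompt dict and tensor shape here are made up for illustration:

```python
import torch

# Hypothetical prompt dict: embeddings supplied, token IDs absent.
prompt = {"prompt_embeds": torch.randn(5, 4096)}  # (seq_len, hidden_size)

# Backfill dummy token IDs so downstream code that expects one ID per
# position keeps working; shape[0] is the number of embedded positions.
if isinstance(prompt, dict) and prompt.get("prompt_embeds", None) is not None:
    if not prompt.get("prompt_token_ids", None):
        prompt["prompt_token_ids"] = [0] * prompt["prompt_embeds"].shape[0]

print(len(prompt["prompt_token_ids"]))  # -> 5
```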
vllm/entrypoints/llm.py

Lines changed: 0 additions & 14 deletions
@@ -9,12 +9,7 @@
 import cloudpickle
 import torch.nn as nn
 from tqdm import tqdm
-<<<<<<< HEAD
 from typing_extensions import TypeVar, deprecated
-=======
-from typing_extensions import deprecated
-import torch
->>>>>>> 0d69ec2f ((vllm) add input embedding)
 
 from vllm import envs
 from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
@@ -386,12 +381,8 @@ def generate(
                        Optional[Union[str, list[str]]]] = None,
         sampling_params: Optional[Union[SamplingParams,
                                         Sequence[SamplingParams]]] = None,
-<<<<<<< HEAD
-        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
-=======
         prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
         prompt_embeds: Optional[torch.Tensor] = None,
->>>>>>> 0d69ec2f ((vllm) add input embedding)
         use_tqdm: bool = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -1238,14 +1229,9 @@ def wake_up(self):
     # LEGACY
     def _convert_v1_inputs(
         self,
-<<<<<<< HEAD
-        prompts: Optional[Union[str, list[str]]],
-        prompt_token_ids: Optional[Union[list[int], list[list[int]]]],
-=======
         prompts: Optional[Union[str, List[str]]],
         prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
         prompt_embeds: Optional[torch.Tensor] = None,
->>>>>>> 0d69ec2f ((vllm) add input embedding)
     ):
         # skip_tokenizer_init is now checked in engine

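Note: with the conflicts resolved, generate() keeps the prompt_embeds parameter from the embedding branch alongside the legacy prompt_token_ids path. A hedged usage sketch; the model name, dtype, and tensor shape are placeholders, and it assumes this branch accepts embeddings without an accompanying text prompt:

```python
import torch
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2-7B-Instruct")  # placeholder model name

# Precomputed prompt embeddings, e.g. from a soft prompt or external encoder.
# Assumed shape: (num_prompt_tokens, hidden_size) for a single prompt.
embeds = torch.randn(16, 3584, dtype=torch.bfloat16)

outputs = llm.generate(
    prompt_embeds=embeds,
    sampling_params=SamplingParams(max_tokens=32),
)
print(outputs[0].outputs[0].text)
```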
vllm/inputs/data.py

Lines changed: 3 additions & 3 deletions
@@ -145,9 +145,6 @@ class TokenInputs(TypedDict):
     prompt_token_ids: List[int]
     """The token IDs of the prompt."""
 
-    prompt_embeds: NotRequired[torch.Tensor]
-    """The embeddings of the prompt, if available."""
-
     token_type_ids: NotRequired[List[int]]
     """The token type IDs of the prompt."""
 
@@ -156,6 +153,9 @@ class TokenInputs(TypedDict):
     The original prompt text corresponding to the token IDs, if available.
     """
 
+    prompt_embeds: NotRequired[torch.Tensor]
+    """The embeddings of the prompt, if available."""
+
     multi_modal_data: NotRequired["MultiModalDataDict"]
     """
     Optional multi-modal data to pass to the model,

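Note: prompt_embeds now sits with the other optional fields of TokenInputs. A self-contained stand-in showing the same layout (this class is illustrative, not the vllm definition; only the fields visible in the diff are included):

```python
from typing import List

import torch
from typing_extensions import NotRequired, TypedDict


class TokenInputsSketch(TypedDict):
    """Illustrative stand-in mirroring the fields shown in the diff above."""

    prompt_token_ids: List[int]
    token_type_ids: NotRequired[List[int]]
    prompt: NotRequired[str]
    prompt_embeds: NotRequired[torch.Tensor]


inputs: TokenInputsSketch = {
    "prompt_token_ids": [0, 0, 0],
    "prompt_embeds": torch.randn(3, 1024),  # optional, one row per placeholder ID
}
```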
vllm/inputs/preprocess.py

Lines changed: 4 additions & 4 deletions
@@ -360,7 +360,7 @@ def _prompt_to_llm_inputs(
 
             return token_inputs(
                 prompt_token_ids=prompt_token_ids,
-                prompt_embeds=tokens_content.get('prompt_embeds'),
+                prompt_embeds=tokens_content.get("prompt_embeds"),
                 token_type_ids=token_type_ids,
                 multi_modal_data=multi_modal_data,
                 mm_processor_kwargs=mm_processor_kwargs,
@@ -390,7 +390,7 @@ def _prompt_to_llm_inputs(
         return token_inputs(
             prompt=prompt_text,
             prompt_token_ids=prompt_token_ids,
-            prompt_embeds=text_content.get('prompt_embeds'),
+            prompt_embeds=text_content.get("prompt_embeds"),
             multi_modal_data=multi_modal_data,
             mm_processor_kwargs=mm_processor_kwargs,
         )
@@ -436,7 +436,7 @@ async def _prompt_to_llm_inputs_async(
 
             return token_inputs(
                 prompt_token_ids=prompt_token_ids,
-                prompt_embeds=tokens_content.get('prompt_embeds'),
+                prompt_embeds=tokens_content.get("prompt_embeds"),
                 multi_modal_data=multi_modal_data,
                 mm_processor_kwargs=mm_processor_kwargs,
             )
@@ -465,7 +465,7 @@ async def _prompt_to_llm_inputs_async(
         return token_inputs(
             prompt=prompt_text,
             prompt_token_ids=prompt_token_ids,
-            prompt_embeds=text_content.get('prompt_embeds'),
+            prompt_embeds=text_content.get("prompt_embeds"),
             multi_modal_data=multi_modal_data,
             mm_processor_kwargs=mm_processor_kwargs,
         )

vllm/model_executor/models/qwen2.py

Lines changed: 1 addition & 5 deletions
@@ -460,14 +460,10 @@ def forward(
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-<<<<<<< HEAD
-        hidden_states = self.model(input_ids, positions, intermediate_tensors,
-                                   inputs_embeds)
-=======
+
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    attn_metadata, intermediate_tensors,
                                    inputs_embeds, self.lm_head.bias)
->>>>>>> 0d69ec2f ((vllm) add input embedding)
         return hidden_states
 
     def compute_logits(

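Note: the resolved forward simply threads inputs_embeds through to self.model. Inside such a model the usual pattern is to skip the token-embedding lookup when embeddings are already provided; a tiny illustrative module (not the Qwen2 code) showing that dispatch:

```python
from typing import Optional

import torch
import torch.nn as nn


class TinyModelSketch(nn.Module):
    """Illustrative only: input_ids vs. inputs_embeds dispatch."""

    def __init__(self, vocab_size: int = 128, hidden_size: int = 32):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.layer = nn.Linear(hidden_size, hidden_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Prefer caller-provided embeddings; otherwise look up the token table.
        hidden_states = (inputs_embeds if inputs_embeds is not None
                         else self.embed_tokens(input_ids))
        return self.layer(hidden_states)
```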
vllm/sequence.py

Lines changed: 9 additions & 9 deletions
@@ -264,14 +264,6 @@ def output_token_ids(self,
                                        new_output_token_ids)
         self._update_cached_all_tokens()
 
-    @property
-    def prompt_embeds(self) -> Optional[torch.Tensor]:
-        return self._prompt_embeds
-
-    @prompt_embeds.setter
-    def prompt_embeds(self, prompt_embeds: Optional[torch.Tensor]) -> None:
-        self._prompt_embeds = prompt_embeds
-
     @property
     def output_token_ids_array(self) -> array:
         """Return the prompt token ids in array type.
@@ -281,6 +273,14 @@ def output_token_ids_array(self) -> array:
         """
         assert isinstance(self._output_token_ids, array)
         return self._output_token_ids
+
+    @property
+    def prompt_embeds(self) -> Optional[torch.Tensor]:
+        return self._prompt_embeds
+
+    @prompt_embeds.setter
+    def prompt_embeds(self, prompt_embeds: torch.Tensor) -> None:
+        self._prompt_embeds = prompt_embeds
 
     @property
     def mrope_position_delta(self) -> Optional[int]:
@@ -389,8 +389,8 @@ def stage(self) -> SequenceStage:
     def __repr__(self) -> str:
         return (f"SequenceData("
                 f"prompt_token_ids={self._prompt_token_ids}, "
+                f"prompt_embeds={getattr(self._prompt_embeds, 'shape', None)}, "
                 f"output_token_ids={self.output_token_ids}, "
-                f"prompt_embeds={getattr(self.prompt_embeds, 'shape', None)}, "
                 f"cumulative_logprob={self.cumulative_logprob}, "
                 f"get_num_computed_tokens={self.get_num_computed_tokens()})")

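Note: the __repr__ uses getattr(..., 'shape', None) so it reports the embedding shape when present and None otherwise, instead of dumping the whole tensor. The idiom in isolation:

```python
import torch

for embeds in (torch.zeros(7, 64), None):
    # Prints torch.Size([7, 64]) for a tensor and None when embeddings are absent.
    print(f"prompt_embeds={getattr(embeds, 'shape', None)}")
```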
vllm/worker/model_runner.py

Lines changed: 8 additions & 9 deletions
@@ -365,8 +365,9 @@ def __init__(
 
         else:
             self.input_tokens = input_tokens or []
-            self.inputs_embeds = (inputs_embeds
-                                  if inputs_embeds is not None else None)
+            self.inputs_embeds = (
+                inputs_embeds if inputs_embeds is not None else None
+            )
             self.input_positions = input_positions or []
             self.token_types = token_types or []
             self.mrope_input_positions = mrope_input_positions or None
@@ -544,12 +545,12 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
 
         # Compute tokens.
         tokens = seq_data.get_token_ids()[context_len:seq_len]
-        if seq_data.prompt_embeds is not None and seq_data.get_output_len(
-        ) == 0:
-            prompt_embeds = seq_data.prompt_embeds[context_len:seq_len]
+        if seq_data.prompt_embeds is not None and seq_data.get_output_len() == 0:
+            prompt_embeds = seq_data.prompt_embeds[context_len:seq_len]
         else:
-            seq_data.prompt_embeds = None
+            seq_data.prompt_embeds = None  # release memory
             prompt_embeds = None
+
         token_types = seq_group_metadata.token_type_ids
 
         inter_data.seq_lens[seq_idx] = seq_len
@@ -870,9 +871,7 @@ def build(self) -> ModelInputForGPU:
             for cur_token_types in inter_data.token_types:
                 token_types.extend(cur_token_types)
             if inter_data.inputs_embeds is not None:
-                inputs_embeds.append(
-                    inter_data.inputs_embeds.to(self.runner.device))
-
+                inputs_embeds.append(inter_data.inputs_embeds.to(self.runner.device))
         if len(inputs_embeds) == 0:
             inputs_embeds = None
         elif len(inputs_embeds) == 1:

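Note: _compute_lens slices the prompt embeddings with the same [context_len:seq_len] window as the token IDs, so chunked prefill only carries the rows still to be computed, and drops the embeddings once the sequence has produced output. A self-contained sketch of that windowing; the lengths and hidden size are made up:

```python
import torch

# Assumed example: a 10-token prompt, 4 tokens already computed, and a
# scheduling window covering the remaining 6 positions.
prompt_embeds = torch.randn(10, 64)   # (prompt_len, hidden_size)
context_len, seq_len = 4, 10
output_len = 0                        # still in prefill, nothing generated yet

if prompt_embeds is not None and output_len == 0:
    # Same windowing as the token IDs: only the rows prefilled this step.
    step_embeds = prompt_embeds[context_len:seq_len]
else:
    prompt_embeds = None              # release memory once decoding has started
    step_embeds = None

print(None if step_embeds is None else tuple(step_embeds.shape))  # (6, 64)
```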