[Core] generate from input embeds #6869

Status: Open. Wants to merge 126 commits into base: main.

Commits (126)
ed8eb22
feat: add support for generate from prompt embeddings
Nan2018 Jul 27, 2024
7ccbf01
Merge remote-tracking branch 'vllm/main' into feature-input-embeds
Nan2018 Jul 27, 2024
7c663a1
Merge branch 'vllm-project:main' into feature-input-embeds
Nan2018 Jul 27, 2024
3bd6423
fix ci errors
Nan2018 Jul 27, 2024
86777a7
Merge remote-tracking branch 'vllm/main' into feature-input-embeds
Nan2018 Aug 5, 2024
9a9d406
fix: tensor parallel
Nan2018 Aug 6, 2024
48fc6a8
style: yapf
Nan2018 Aug 6, 2024
737d01b
fix: model_runner in a WorkerWrapper
Nan2018 Aug 6, 2024
4b99109
Merge remote-tracking branch 'vllm/main' into feature-input-embeds
Nan2018 Aug 7, 2024
03344ab
fix: spec decoding model
Nan2018 Aug 7, 2024
40038e0
fix: ruff
Nan2018 Aug 7, 2024
b00bbc7
fix: move param prompt_embeds_shape to the last of RequestOutput
Nan2018 Aug 8, 2024
3c1a6fa
feat: all *ForCausalLM models support inputs_embeds
Nan2018 Aug 8, 2024
c454647
fix: format
Nan2018 Aug 8, 2024
535ad97
fix: format
Nan2018 Aug 8, 2024
c05a8ff
fix: format
Nan2018 Aug 8, 2024
2b50573
Merge remote-tracking branch 'vllm/main' into feature-input-embeds
Nan2018 Sep 4, 2024
7ddc863
Merge remote-tracking branch 'vllm/main' into feature-input-embeds
Nan2018 Sep 4, 2024
d83915b
fix: engines
Nan2018 Sep 4, 2024
dfd9301
fix: format
Nan2018 Sep 4, 2024
fd455eb
fix: format
Nan2018 Sep 4, 2024
03fcf3b
fix: sequence property and embeddings for phi3v
Nan2018 Sep 5, 2024
40f485b
Merge remote-tracking branch 'vllm/main' into feature-input-embeds
Nan2018 Sep 5, 2024
a72930f
fix: ultravox
Nan2018 Sep 6, 2024
5e3eec9
refactor: rename parameter
Nan2018 Sep 6, 2024
beba406
Merge remote-tracking branch 'vllm/main' into feature-input-embeds
Nan2018 Sep 11, 2024
5cde3b4
Merge remote-tracking branch 'vllm/main' into feature-input-embeds
Nan2018 Sep 11, 2024
2b39026
refactor: supports_input_embeds
Nan2018 Sep 11, 2024
29525e1
Merge remote-tracking branch 'vllm/main' into feature-input-embeds
Nan2018 Sep 12, 2024
49fe3f7
feat: inputs_embeds for new models
Nan2018 Sep 12, 2024
40ca516
Merge remote-tracking branch 'vllm/main' into feature-input-embeds
Nan2018 Sep 17, 2024
56b9ac5
Fix typing
DarkLight1337 Sep 19, 2024
decb8ab
Support embeds in minicpm
DarkLight1337 Sep 19, 2024
ee41cb7
Fix typing 2
DarkLight1337 Sep 19, 2024
fd58d4b
Disable `inputs_embeds` for multimodal models as it conflicts with mu…
DarkLight1337 Sep 19, 2024
5849509
Reformat
DarkLight1337 Sep 19, 2024
f03d980
Merge branch 'main' into feature-input-embeds
DarkLight1337 Sep 19, 2024
23a4876
Optimize
DarkLight1337 Sep 19, 2024
fbf3f10
Fix unbound variables
DarkLight1337 Sep 19, 2024
9822652
Cleanup
DarkLight1337 Sep 19, 2024
53962c4
Cleanup 2
DarkLight1337 Sep 19, 2024
b8137aa
Fix unbound `prompt_embeds` in validation and clean it up
DarkLight1337 Sep 19, 2024
2cf3b4b
Indent
DarkLight1337 Sep 19, 2024
716b64a
Fix unbound local and cleanup 2
DarkLight1337 Sep 19, 2024
7dd3d86
fix: gemma
Nan2018 Sep 19, 2024
8059014
Have two distinct `SequenceData` classes, one for tokens and one for …
DarkLight1337 Sep 20, 2024
7f8ed8c
Rename `PromptInputs` to `PromptType`
DarkLight1337 Sep 20, 2024
89b5753
Fix type errors
DarkLight1337 Sep 20, 2024
8e91af3
format
DarkLight1337 Sep 20, 2024
31e2a1b
Rename `LLMInputs` to `DecoderOnlyInputs` and fix input processing fo…
DarkLight1337 Sep 20, 2024
60bc7b5
Fix error on import
DarkLight1337 Sep 20, 2024
a8483a4
Revert class splitting
DarkLight1337 Sep 20, 2024
f451192
Fix init error for embeds
DarkLight1337 Sep 20, 2024
3d436e9
Be a bit more efficient
DarkLight1337 Sep 20, 2024
dfecf4b
Rename
DarkLight1337 Sep 20, 2024
29b9d5c
Rename 2
DarkLight1337 Sep 20, 2024
d5ec13a
Fix encoder-decoder test prompts
DarkLight1337 Sep 20, 2024
b5632f9
Fix validation for encoder-decoder models
DarkLight1337 Sep 20, 2024
881e6da
Merge branch 'main' into feature-input-embeds
DarkLight1337 Sep 20, 2024
3da5ad6
Fix naming
DarkLight1337 Sep 20, 2024
94ace38
Rename `PromptInputs` to `PromptType`, and `inputs` to `prompt`
DarkLight1337 Sep 20, 2024
065a304
Remove unnecessary comments
DarkLight1337 Sep 20, 2024
affbd12
Merge branch 'rename-prompt' into feature-input-embeds
DarkLight1337 Sep 20, 2024
2e3cb4a
Merge branch 'main' into feature-input-embeds
DarkLight1337 Sep 21, 2024
e5a771a
Merge branch 'main' into feature-input-embeds
DarkLight1337 Sep 21, 2024
dfc108d
Merge branch 'main' into feature-input-embeds
DarkLight1337 Sep 22, 2024
741d4c1
Fix import
DarkLight1337 Sep 22, 2024
9350022
Format
DarkLight1337 Sep 22, 2024
d29d563
Merge branch 'main' into feature-input-embeds
DarkLight1337 Sep 28, 2024
1de2b99
Add validation for embedding inputs
DarkLight1337 Sep 28, 2024
9124115
Update for mllama
DarkLight1337 Sep 28, 2024
6c366eb
format
DarkLight1337 Sep 28, 2024
280596c
Merge branch 'main' into feature-input-embeds
DarkLight1337 Oct 1, 2024
ad6c364
Fix failing tests
DarkLight1337 Oct 1, 2024
529e91d
Merge branch 'main' into feature-input-embeds
DarkLight1337 Oct 5, 2024
13bbd02
format
DarkLight1337 Oct 5, 2024
0946dc6
Merge branch 'main' into feature-input-embeds
DarkLight1337 Oct 7, 2024
4000b90
Improve type annotations
DarkLight1337 Oct 7, 2024
7a26dd0
Merge branch 'main' into feature-input-embeds
DarkLight1337 Oct 16, 2024
1bafe1b
Update
DarkLight1337 Oct 16, 2024
0d6331d
Merge branch 'main' into feature-input-embeds
DarkLight1337 Oct 22, 2024
2da9eea
Fix KeyError; debug
DarkLight1337 Oct 22, 2024
0c872b3
Make encoder-decoder inputs a composed structure
DarkLight1337 Oct 23, 2024
9287a1b
Merge branch 'main' into refactor-preprocessing
DarkLight1337 Oct 23, 2024
70db2db
Merge branch 'refactor-preprocessing' into feature-input-embeds
DarkLight1337 Oct 23, 2024
fa5ad17
Rename
DarkLight1337 Oct 23, 2024
44fd058
Fix type error
DarkLight1337 Oct 23, 2024
3ce9236
Merge branch 'refactor-preprocessing' into feature-input-embeds
DarkLight1337 Oct 23, 2024
d167df3
Fix bad merge
DarkLight1337 Oct 23, 2024
b73a345
Fix test
DarkLight1337 Oct 23, 2024
fa968b5
Fix llama-3.2
DarkLight1337 Oct 23, 2024
6e0934c
Merge branch 'refactor-preprocessing' into feature-input-embeds
DarkLight1337 Oct 23, 2024
5ccc390
Fix wrong variable
DarkLight1337 Oct 23, 2024
2fe159c
Impl get_inputs_embeds
DarkLight1337 Oct 23, 2024
7986553
Fix tests
DarkLight1337 Oct 23, 2024
4c072ca
format
DarkLight1337 Oct 23, 2024
3b22bbc
Don't use `prompt_embeds` in `get_len`
DarkLight1337 Oct 23, 2024
906ee1e
Remove force_bos
DarkLight1337 Oct 24, 2024
005ad95
Add example output
DarkLight1337 Oct 24, 2024
a5f0c16
format
DarkLight1337 Oct 24, 2024
6ab44e4
Fix
DarkLight1337 Oct 24, 2024
21be11f
Merge branch 'main' into refactor-preprocessing
DarkLight1337 Oct 29, 2024
1f927d2
Merge branch 'main' into refactor-preprocessing
DarkLight1337 Oct 31, 2024
760db05
Fix merge
DarkLight1337 Oct 31, 2024
acb8e6f
Update mllama processing
DarkLight1337 Oct 31, 2024
3bed519
Fix line
DarkLight1337 Oct 31, 2024
ea861e0
format
DarkLight1337 Oct 31, 2024
f654421
Avoid repeated lookups
DarkLight1337 Oct 31, 2024
594794e
Remove unused import
DarkLight1337 Oct 31, 2024
08ea824
Fix mypy
DarkLight1337 Oct 31, 2024
b622f41
Merge branch 'main' into refactor-preprocessing
DarkLight1337 Oct 31, 2024
800960d
Merge branch 'main' into refactor-preprocessing
DarkLight1337 Oct 31, 2024
283bc2c
Fix merge
DarkLight1337 Oct 31, 2024
e8169ea
Merge branch 'main' into refactor-preprocessing
DarkLight1337 Nov 2, 2024
61bf1d1
Merge branch 'main' into refactor-preprocessing
DarkLight1337 Nov 3, 2024
b45cdc9
Fix missing import
DarkLight1337 Nov 3, 2024
4d33b1e
Improve error message
DarkLight1337 Nov 3, 2024
0a549e5
Add missing export
DarkLight1337 Nov 3, 2024
f741a75
Improve error message.
DarkLight1337 Nov 3, 2024
cd231fa
Format
DarkLight1337 Nov 3, 2024
e4f3a93
Merge branch 'refactor-preprocessing' into feature-input-embeds
DarkLight1337 Nov 4, 2024
c61d246
Merge branch 'main' into feature-input-embeds
DarkLight1337 Nov 5, 2024
b8aaa8e
Merge branch 'main' into feature-input-embeds
DarkLight1337 Nov 6, 2024
c8fc1fe
Update `get_inputs_embeds` to be compatible with `torch.compile`
DarkLight1337 Nov 6, 2024
85124f7
Merge branch 'main' into feature-input-embeds
DarkLight1337 Nov 6, 2024
a7429ad
Merge branch 'main' into feature-input-embeds
DarkLight1337 Nov 6, 2024
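The file diffs below wire precomputed prompt embeddings through the test runner and the model runner. As a rough illustration of the user-facing feature, here is a hypothetical usage sketch; it assumes the `LLM.generate` entry point accepts the same `TextPrompt` / `EmbedsPrompt` dictionaries that the updated test runner builds, which is implied but not shown in this excerpt.

```python
import torch

from vllm import LLM, SamplingParams
from vllm.inputs import EmbedsPrompt, TextPrompt  # EmbedsPrompt is added by this PR

llm = LLM(model="facebook/opt-125m")
params = SamplingParams(temperature=0.0, max_tokens=16)

# Regular text prompt.
text_outputs = llm.generate(TextPrompt(prompt="Hello, my name is"), params)

# Precomputed prompt embeddings of shape (seq_len, hidden_size). In practice
# these would come from the model's own input embedding layer rather than
# random values; hidden_size is model-dependent (768 for opt-125m).
prompt_embeds = torch.rand(5, 768)
embeds_outputs = llm.generate(EmbedsPrompt(prompt_embeds=prompt_embeds), params)

for output in (*text_outputs, *embeds_outputs):
    print(output.outputs[0].text)
```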
32 changes: 19 additions & 13 deletions tests/conftest.py
@@ -27,8 +27,9 @@
from vllm.distributed import (cleanup_dist_env_and_memory,
init_distributed_environment,
initialize_model_parallel)
from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
to_enc_dec_tuple_list, zip_enc_dec_prompts)
from vllm.inputs import (EmbedsPrompt, ExplicitEncoderDecoderPrompt,
TextPrompt, to_enc_dec_tuple_list,
zip_enc_dec_prompts)
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
@@ -654,21 +655,26 @@ def __init__(

def get_inputs(
self,
prompts: List[str],
prompts_or_prompt_embeds: Union[List[str], List[torch.Tensor]],
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
) -> List[TextPrompt]:
) -> List[Union[TextPrompt, EmbedsPrompt]]:
if images is not None:
assert len(prompts) == len(images)
assert len(prompts_or_prompt_embeds) == len(images)

if videos is not None:
assert len(prompts) == len(videos)
assert len(prompts_or_prompt_embeds) == len(videos)

if audios is not None:
assert len(prompts) == len(audios)
assert len(prompts_or_prompt_embeds) == len(audios)

inputs = [
EmbedsPrompt(prompt_embeds=prompt) if isinstance(
prompt, torch.Tensor) else TextPrompt(prompt=prompt)
for prompt in prompts_or_prompt_embeds
]

inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
if images is not None:
for i, image in enumerate(images):
if image is not None:
@@ -696,13 +702,13 @@ def classify(self, prompts: List[str]) -> List[str]:

def generate(
self,
prompts: List[str],
prompts_or_prompt_embeds: Union[List[str], List[torch.Tensor]],
sampling_params: SamplingParams,
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
) -> List[Tuple[List[List[int]], List[str]]]:
inputs = self.get_inputs(prompts,
inputs = self.get_inputs(prompts_or_prompt_embeds,
images=images,
videos=videos,
audios=audios)
@@ -720,7 +726,7 @@ def generate(
output_str = sample.text
output_ids = list(sample.token_ids)
req_sample_output_ids.append(prompt_ids + output_ids)
req_sample_output_strs.append(prompt_str + output_str)
req_sample_output_strs.append((prompt_str or "") + output_str)
outputs.append((req_sample_output_ids, req_sample_output_strs))
return outputs

@@ -785,14 +791,14 @@ def generate_encoder_decoder_w_logprobs(

def generate_greedy(
self,
prompts: List[str],
prompts_or_prompt_embeds: Union[List[str], List[torch.Tensor]],
max_tokens: int,
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
) -> List[Tuple[List[int], str]]:
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
outputs = self.generate(prompts,
outputs = self.generate(prompts_or_prompt_embeds,
greedy_params,
images=images,
videos=videos,
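The key change in `get_inputs` above is the dispatch on prompt type. Pulled out as a standalone sketch (mirroring, not reproducing, the runner code):

```python
from typing import List, Union

import torch

from vllm.inputs import EmbedsPrompt, TextPrompt


def to_prompt_inputs(
    prompts_or_prompt_embeds: Union[List[str], List[torch.Tensor]],
) -> List[Union[TextPrompt, EmbedsPrompt]]:
    # Tensors are treated as precomputed input embeddings; strings are passed
    # through as plain text prompts, as in VllmRunner.get_inputs.
    return [
        EmbedsPrompt(prompt_embeds=prompt)
        if isinstance(prompt, torch.Tensor) else TextPrompt(prompt=prompt)
        for prompt in prompts_or_prompt_embeds
    ]
```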
20 changes: 20 additions & 0 deletions tests/models/decoder_only/language/test_models.py
@@ -54,9 +54,22 @@ def test_models(
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)

prompt_embeds = []
prompt_token_ids = []
for prompt in example_prompts:
token_ids = hf_model.tokenizer(prompt,
return_tensors="pt").input_ids.to(
hf_model.model.device)
prompt_token_ids.append(token_ids)
prompt_embeds.append(
hf_model.model.get_input_embeddings()(token_ids).squeeze(0))

with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
vllm_outputs_from_embeds = vllm_model.generate_greedy_logprobs(
prompt_embeds, max_tokens, num_logprobs)

# This test is for verifying whether the model's extra_repr
# can be printed correctly.
print(vllm_model.model.llm_engine.model_executor.driver_worker.
@@ -68,3 +81,10 @@
name_0="hf",
name_1="vllm",
)

check_logprobs_close(
outputs_0_lst=vllm_outputs,
outputs_1_lst=vllm_outputs_from_embeds,
name_0="vllm",
name_1="vllm_from_embeds",
)
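The embedding inputs in this test come from the Hugging Face reference model itself, so the embeds path can be checked for logprob closeness against the token path. A minimal sketch of that extraction outside the test harness (the model name is illustrative; any `*ForCausalLM` checkpoint works):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "facebook/opt-125m"  # illustrative choice
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

prompt = "vLLM is a high-throughput inference engine"
token_ids = tokenizer(prompt, return_tensors="pt").input_ids

with torch.no_grad():
    # (1, seq_len, hidden_size) -> (seq_len, hidden_size), the shape passed
    # to the vLLM runner as a single embeddings prompt.
    prompt_embeds = model.get_input_embeddings()(token_ids).squeeze(0)
```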
109 changes: 85 additions & 24 deletions tests/worker/test_model_runner.py
@@ -1,3 +1,4 @@
import random
from typing import List

import pytest
@@ -22,8 +23,9 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
return model_runner


@pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_prepare_prompt(batch_size):
@pytest.mark.parametrize("batch_size", list(range(1, 257, 3)))
@pytest.mark.parametrize("prompt_embeds_ratio", (0.0, 0.5, 1.0))
def test_prepare_prompt(batch_size, prompt_embeds_ratio):
model_runner = _create_model_runner(
"facebook/opt-125m",
max_num_batched_tokens=100000,
@@ -34,11 +36,19 @@ def test_prepare_prompt(batch_size):
seq_lens: List[int] = []
seq_group_metadata_list: List[SequenceGroupMetadata] = []
block_tables = {0: [1]}
input_embeds_len = 0
for i in range(batch_size):
# make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len)
seq_data = SequenceData.from_seqs(range(seq_len))
if random.random() < prompt_embeds_ratio:
seq_data = SequenceData.from_seqs(
range(seq_len),
prompt_embeds=torch.rand(seq_len, 10),
)
input_embeds_len += seq_len
else:
seq_data = SequenceData.from_seqs(range(seq_len))
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
@@ -59,6 +69,8 @@ def test_prepare_prompt(batch_size):
seq_group_metadata_list)
input_tokens = model_input.input_tokens
input_positions = model_input.input_positions
input_embeds = model_input.input_embeds
input_embeds_masks = model_input.input_embeds_masks
attn_metadata = model_input.attn_metadata
return_seq_lens = model_input.seq_lens
slot_mapping = attn_metadata.slot_mapping
@@ -112,7 +124,12 @@ def test_prepare_prompt(batch_size):

assert len(input_tokens) == sum(seq_lens)
assert len(input_positions) == sum(seq_lens)
torch.testing.assert_close(input_tokens, input_positions)
assert len(input_embeds_masks) == sum(seq_lens)
if input_embeds_len == 0:
torch.testing.assert_close(input_tokens, input_positions)
assert input_embeds is None
else:
assert len(input_embeds) == input_embeds_len

sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
@@ -136,8 +153,9 @@
torch.testing.assert_close(actual, expected)


@pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_prepare_decode_cuda_graph(batch_size):
@pytest.mark.parametrize("batch_size", list(range(1, 257, 3)))
@pytest.mark.parametrize("prompt_embeds_ratio", (0.0, 0.5, 1.0))
def test_prepare_decode_cuda_graph(batch_size, prompt_embeds_ratio):
model_runner = _create_model_runner(
"facebook/opt-125m",
seed=0,
@@ -151,11 +169,19 @@ def test_prepare_decode_cuda_graph(batch_size):
context_lens: List[int] = []
seq_group_metadata_list: List[SequenceGroupMetadata] = []
# Assume each seq group finishes prefill.
input_embeds_len = 0
for i in range(batch_size):
# make sure all tokens fit into one block
context_len = i % (model_runner.block_size - 1) + 1
context_lens.append(context_len)
seq_data = SequenceData.from_seqs(range(context_len))
if random.random() < prompt_embeds_ratio:
seq_data = SequenceData.from_seqs(
[],
prompt_embeds=torch.rand(context_len, 10),
)
input_embeds_len += context_len
else:
seq_data = SequenceData.from_seqs(range(context_len))
seq_data.update_num_computed_tokens(context_len)
# Append one token ID since prefill is finished.
seq_data.append_token_id(1, 0)
@@ -171,9 +197,13 @@ def test_prepare_decode_cuda_graph(batch_size):

model_input = model_runner._prepare_model_input_tensors(
seq_group_metadata_list)
input_tokens, input_positions, attn_metadata, slot_mapping = (
model_input.input_tokens, model_input.input_positions,
model_input.attn_metadata, model_input.attn_metadata.slot_mapping)
(input_tokens, input_positions, input_embeds, input_embeds_masks,
attn_metadata,
slot_mapping) = (model_input.input_tokens, model_input.input_positions,
model_input.input_embeds, model_input.input_embeds_masks,
model_input.attn_metadata,
model_input.attn_metadata.slot_mapping)

assert len(slot_mapping) == len(input_tokens)

expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
@@ -226,6 +256,8 @@ def test_prepare_decode_cuda_graph(batch_size):
assert len(input_tokens) == expected_bs
assert len(input_positions) == expected_bs
torch.allclose(input_tokens, input_positions)
assert input_embeds is None
assert input_embeds_masks is None

# Verify Sampling
expected_selected_token_indices = []
@@ -256,14 +288,19 @@ def test_empty_seq_group():
seq_group_metadata_list: List[SequenceGroupMetadata] = []
model_input = model_runner._prepare_model_input_tensors(
seq_group_metadata_list)
input_tokens, input_positions, attn_metadata = (
model_input.input_tokens,
model_input.input_positions,
model_input.attn_metadata,
)
(input_tokens, input_positions, input_embeds, input_embeds_masks,
attn_metadata) = (
model_input.input_tokens,
model_input.input_positions,
model_input.input_embeds,
model_input.input_embeds_masks,
model_input.attn_metadata,
)
assert input_tokens is None
assert input_positions is None
assert attn_metadata is None
assert input_embeds is None
assert input_embeds_masks is None

model_input = model_runner._prepare_model_input_tensors(
seq_group_metadata_list)
@@ -289,9 +326,11 @@ def distributed_init():
ensure_model_parallel_initialized(1, 1)


@pytest.mark.parametrize("batch_size", list(range(2, 128)))
@pytest.mark.parametrize("batch_size", list(range(2, 128, 3)))
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
@pytest.mark.parametrize('prompt_embeds_ratio', [0.0, 0.5, 1.0])
def test_hybrid_batches(batch_size, enforce_eager, prompt_embeds_ratio,
distributed_init):
model_runner = _create_model_runner(
"facebook/opt-125m",
seed=0,
@@ -310,11 +349,19 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
block_tables = {0: [1]}
prefill_batch_size = batch_size // 2
decode_batch_size = batch_size - prefill_batch_size
input_embeds_len = 0
for i in range(prefill_batch_size):
# make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len)
seq_data = SequenceData.from_seqs(range(seq_len))
if random.random() < prompt_embeds_ratio:
seq_data = SequenceData.from_seqs(
[],
prompt_embeds=torch.rand(seq_len, 10),
)
input_embeds_len += seq_len
else:
seq_data = SequenceData.from_seqs(range(seq_len))
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
@@ -330,7 +377,13 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
for i in range(prefill_batch_size, batch_size):
# make sure all tokens fit into one block
context_len = i % (model_runner.block_size - 1) + 1
seq_data = SequenceData.from_seqs(range(context_len))
if random.random() < prompt_embeds_ratio:
seq_data = SequenceData.from_seqs(
[],
prompt_embeds=torch.rand(context_len, 10),
)
else:
seq_data = SequenceData.from_seqs(range(context_len))
seq_data.append_token_id(1, 0)
seq_data.update_num_computed_tokens(context_len)
seq_group_metadata = SequenceGroupMetadata(
@@ -345,11 +398,14 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
decode_metadata_list.append(seq_group_metadata)

model_input = model_runner.prepare_model_input(seq_group_metadata_list)
(input_tokens, input_positions, attn_metadata) = (
model_input.input_tokens,
model_input.input_positions,
model_input.attn_metadata,
)
(input_tokens, input_positions, attn_metadata, input_embeds,
input_embeds_masks) = (
model_input.input_tokens,
model_input.input_positions,
model_input.attn_metadata,
model_input.input_embeds,
model_input.input_embeds_masks,
)

prefill_meta_actual = attn_metadata.prefill_metadata
decode_meta_actual = attn_metadata.decode_metadata
@@ -359,6 +415,11 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
assert attn_metadata.num_prefills == prefill_batch_size
assert attn_metadata.num_decode_tokens == decode_batch_size
assert attn_metadata.num_prefill_tokens == sum(seq_lens)
assert len(input_embeds_masks) == sum(seq_lens)
if input_embeds_len == 0:
assert input_embeds is None
else:
assert len(input_embeds) == input_embeds_len

# Verify attn metadata is consistent. We don't need to test individual
# values here because they are tested above.
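The model-runner tests above mix token-backed and embedding-backed sequences through `SequenceData.from_seqs(..., prompt_embeds=...)`. A small sketch of the two construction paths, assuming `SequenceData` is importable from `vllm.sequence` (the import is outside this excerpt):

```python
import torch

from vllm.sequence import SequenceData  # assumed import path

seq_len, hidden_size = 8, 10  # the tests use a dummy hidden size of 10

# Token-backed prefill sequence, as in the pre-existing tests.
token_seq = SequenceData.from_seqs(range(seq_len))

# Embedding-backed prefill sequence: the token ids can be empty, and the
# prompt is carried entirely by a (seq_len, hidden_size) tensor.
embeds_seq = SequenceData.from_seqs(
    [], prompt_embeds=torch.rand(seq_len, hidden_size))
```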