From d1670d139cdad4d2dfcc6248817f4130e90e451f Mon Sep 17 00:00:00 2001
From: Yanming Wang
Date: Thu, 2 Nov 2023 00:56:54 +0000
Subject: [PATCH] Fix input_metadata.selected_token_indices in worker
 prepare_inputs

---
 tests/worker/test_worker.py | 44 +++++++++++++++++++++++++++++++++++++
 vllm/worker/worker.py       |  4 +++-
 2 files changed, 47 insertions(+), 1 deletion(-)
 create mode 100644 tests/worker/test_worker.py

diff --git a/tests/worker/test_worker.py b/tests/worker/test_worker.py
new file mode 100644
index 0000000000000..d7ed50acb2295
--- /dev/null
+++ b/tests/worker/test_worker.py
@@ -0,0 +1,44 @@
+# pylint: disable=protected-access
+import random
+import torch
+
+from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
+from vllm.worker.worker import Worker
+
+
+def test_worker_prepare_inputs_for_prompt():
+    worker = Worker(None, None, None)
+    worker.block_size = 16
+    batch_size = random.randint(1, 256)
+    prompt_lens = []
+    seq_group_metadata_list = []
+    for i in range(batch_size):
+        # make sure all tokens fit into one block
+        prompt_len = i % (worker.block_size - 1) + 1
+        prompt_lens.append(prompt_len)
+        seq_data = list(range(prompt_len))
+        seq_group_metadata_list.append(
+            SequenceGroupMetadata(
+                request_id=f"test_{i}",
+                is_prompt=True,
+                seq_data={0: SequenceData(seq_data)},
+                sampling_params=SamplingParams(temperature=0),
+                block_tables={0: [1]},
+            ))
+    expected_selected_token_indices = []
+    selected_token_start_idx = 0
+    max_seq_len = max(prompt_lens)
+    for prompt_len in prompt_lens:
+        expected_selected_token_indices.append(selected_token_start_idx +
+                                               prompt_len - 1)
+        selected_token_start_idx += max_seq_len
+    input_tokens, input_positions, input_metadata = worker._prepare_inputs(
+        seq_group_metadata_list)
+    assert input_tokens.shape == input_positions.shape == (batch_size,
+                                                           max_seq_len)
+    torch.testing.assert_close(input_tokens, input_positions)
+    actual = input_metadata.selected_token_indices
+    expected = torch.tensor(expected_selected_token_indices,
+                            device=actual.device,
+                            dtype=actual.dtype)
+    torch.testing.assert_close(actual, expected)
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index d598a86cf0c1c..ade6f145a1cfa 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -218,12 +218,14 @@ def _prepare_inputs(
         context_lens: List[int] = []
         generation_block_tables: List[List[int]] = []
         max_seq_len = max(prompt_lens) if prompt_lens else 1
-        for seq_group_metadata in seq_group_metadata_list:
+        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
             if seq_group_metadata.is_prompt:
                 # We need to do this in this loop as we need to know max_seq_len
                 assert len(
                     seq_ids) == 1, "Prompt input should have only one seq."
                 sampling_params = seq_group_metadata.sampling_params
+                assert len(prompt_lens) == len(seq_group_metadata_list)
+                prompt_len = prompt_lens[i]
                 if sampling_params.prompt_logprobs is not None:
                     selected_token_indices.extend(
                         range(selected_token_start_idx,