
Commit

[Core] Factor out common code in SequenceData and Sequence (vllm-…
DarkLight1337 authored and dtrifiro committed Sep 27, 2024
1 parent 344b05d commit 34d711b
Showing 8 changed files with 64 additions and 97 deletions.
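
Before the per-file diffs, a minimal sketch of the call-site change this commit makes (illustrative only, not part of the diff; it uses just the factory methods added in vllm/sequence.py below):

from vllm.sequence import SequenceData

# Previously, callers built the backing token array by hand:
#     SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
# After this commit they go through the new factory methods instead:
prompt_only = SequenceData.from_seqs([1, 2, 3])
prompt_and_output = SequenceData.from_seqs([1, 2, 3], [4, 5])  # prompt + already-generated tokens
dummy_prompt = SequenceData.from_counts({0: 16})  # 16 copies of token id 0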
27 changes: 7 additions & 20 deletions tests/samplers/test_sampler.py
@@ -1,6 +1,5 @@
 import itertools
 import random
-from array import array
 from typing import Dict, List, Optional, Tuple
 from unittest.mock import Mock, patch

@@ -12,8 +11,7 @@
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_random_seed
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
-                           SequenceData, SequenceGroupMetadata)
+from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.utils import Counter, is_pin_memory_available


@@ -59,9 +57,7 @@ def _do_sample(
             SequenceGroupMetadata(
                 request_id=f"test_{i}",
                 is_prompt=True,
-                seq_data={
-                    0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
-                },
+                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                 sampling_params=sampling_params,
                 block_tables={0: [1]},
             ))
@@ -205,9 +201,8 @@ def create_sampling_params(min_tokens,
         return sampling_params
 
     def create_sequence_data(num_input=3, num_generated=0):
-        seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                  random.choices(range(0, VOCAB_SIZE), k=num_input)))
+        seq_data = SequenceData.from_seqs(
+            random.choices(range(0, VOCAB_SIZE), k=num_input))
         if num_generated > 0:
             seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE),
                                                        k=num_generated)
@@ -511,9 +506,7 @@ def test_sampler_mixed(seed: int, device: str):
             SequenceGroupMetadata(
                 request_id=f"test_{i}",
                 is_prompt=True,
-                seq_data={
-                    0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
-                },
+                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                 sampling_params=sampling_params,
                 block_tables={0: [1]},
             ))
@@ -613,9 +606,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
             SequenceGroupMetadata(
                 request_id=f"test_{i}",
                 is_prompt=True,
-                seq_data={
-                    0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
-                },
+                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                 sampling_params=SamplingParams(
                     temperature=1,
                     top_k=top_k,
@@ -699,11 +690,7 @@ def test_sampling_params(sampling_params: List[SamplingParams]):
             SequenceGroupMetadata(
                 request_id=f"test_{i}",
                 is_prompt=True,
-                seq_data={
-                    0:
-                    SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                       [1, 2, 3]))
-                },
+                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                 sampling_params=sampling_params[i],
                 block_tables={0: [1]},
             ))
12 changes: 3 additions & 9 deletions tests/spec_decode/utils.py
@@ -1,4 +1,3 @@
-from array import array
 from itertools import count
 from typing import Callable, Dict, List, Optional
 from typing import Sequence as GenericSequence
@@ -11,8 +10,7 @@
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.utils import set_random_seed
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE,
-                           CompletionSequenceGroupOutput, Logprob,
+from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                            SequenceData, SequenceGroupMetadata, SequenceOutput)
 from vllm.utils import get_distributed_init_method, get_ip, get_open_port
 from vllm.worker.cache_engine import CacheEngine
@@ -138,12 +136,8 @@ def create_seq_group_metadata_from_prompts(
             request_id=str(i),
             is_prompt=len(cont_token_ids) == 0,
             seq_data={
-                i:
-                SequenceData(
-                    array(VLLM_TOKEN_ID_ARRAY_TYPE, prompt_token_ids[:]),
-                    _output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                            cont_token_ids[:]),
-                ),
+                i: SequenceData.from_seqs(prompt_token_ids[:],
+                                          cont_token_ids[:]),
             },
             sampling_params=SamplingParams(temperature=0.0, ),
             block_tables={i: block_allocations[i][:]},
8 changes: 2 additions & 6 deletions tests/test_logits_processor.py
@@ -1,5 +1,4 @@
 import random
-from array import array
 from typing import Tuple
 from unittest.mock import patch

@@ -9,8 +8,7 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_random_seed
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
-                           SequenceData, SequenceGroupMetadata)
+from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.utils import is_pin_memory_available


@@ -71,9 +69,7 @@ def pick_ith(token_ids, logits):
             SequenceGroupMetadata(
                 request_id=f"test_{i}",
                 is_prompt=True,
-                seq_data={
-                    0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
-                },
+                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                 sampling_params=SamplingParams(temperature=0,
                                                logits_processors=[pick_ith]),
                 block_tables={0: [1]},
7 changes: 2 additions & 5 deletions tests/test_sequence.py
@@ -1,10 +1,7 @@
-from array import array
-
 import pytest
 
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE,
-                           CompletionSequenceGroupOutput, SequenceData,
+from vllm.sequence import (CompletionSequenceGroupOutput, SequenceData,
                            SequenceOutput)
 
 from .core.utils import create_dummy_prompt
@@ -58,7 +55,7 @@ def test_sampler_output_eq(sample_outputs):


 def test_sequence_data_prefill():
-    seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3, 4]))
+    seq_data = SequenceData.from_seqs([1, 2, 3, 4])
     assert seq_data.get_num_uncomputed_tokens() == 4
     assert seq_data.get_num_computed_tokens() == 0
     # advance by 2
22 changes: 7 additions & 15 deletions tests/worker/test_encoder_decoder_model_runner.py
@@ -1,13 +1,11 @@
 import itertools
-from array import array
 from typing import List
 
 import pytest
 import torch
 
 from vllm.engine.arg_utils import EngineArgs
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
-                           SequenceData, SequenceGroupMetadata)
+from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.utils import is_cpu, make_tensor_with_pad
 from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
 from vllm.worker.model_runner import _get_graph_batch_size
@@ -119,12 +117,10 @@ def test_prepare_prompt(batch_size):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
         seq_lens.append(seq_len)
-        seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                      range(seq_len)))
+        seq_data = SequenceData.from_seqs(range(seq_len))
         encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
         encoder_seq_lens.append(encoder_seq_len)
-        encoder_seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, range(encoder_seq_len)))
+        encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=True,
@@ -317,11 +313,9 @@ def test_prepare_decode(batch_size, multiple_seqs_per_seq_group):
     for i in range(batch_size):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
-        seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len))))
+        seq_data = SequenceData.from_seqs(range(seq_len))
         encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
-        encoder_seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len))))
+        encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len))
 
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
@@ -523,11 +517,9 @@ def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group):
     for i in range(batch_size):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
-        seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len))))
+        seq_data = SequenceData.from_seqs(range(seq_len))
         encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
-        encoder_seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len))))
+        encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=False,
16 changes: 5 additions & 11 deletions tests/worker/test_model_runner.py
@@ -1,4 +1,3 @@
-from array import array
 from typing import List
 
 import pytest
@@ -8,8 +7,7 @@
                                      init_distributed_environment)
 from vllm.engine.arg_utils import EngineArgs
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
-                           SequenceData, SequenceGroupMetadata)
+from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.utils import get_open_port
 from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size

@@ -48,8 +46,7 @@ def test_prepare_prompt(batch_size):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
         seq_lens.append(seq_len)
-        seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                      range(seq_len)))
+        seq_data = SequenceData.from_seqs(range(seq_len))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=True,
@@ -166,8 +163,7 @@ def test_prepare_decode_cuda_graph(batch_size):
         # make sure all tokens fit into one block
         context_len = i % (model_runner.block_size - 1) + 1
         context_lens.append(context_len)
-        seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, range(context_len)))
+        seq_data = SequenceData.from_seqs(range(context_len))
         seq_data.update_num_computed_tokens(context_len)
         # Append one token ID since prefill is finished.
         seq_data.append_token_id(1, 0)
@@ -326,8 +322,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
         seq_lens.append(seq_len)
-        seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                      range(seq_len)))
+        seq_data = SequenceData.from_seqs(range(seq_len))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=True,
@@ -343,8 +338,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
     for i in range(prefill_batch_size, batch_size):
         # make sure all tokens fit into one block
         context_len = i % (model_runner.block_size - 1) + 1
-        prompt_toks = array(VLLM_TOKEN_ID_ARRAY_TYPE, range(context_len))
-        seq_data = SequenceData(prompt_toks)
+        seq_data = SequenceData.from_seqs(range(context_len))
         seq_data.append_token_id(1, 0)
         seq_data.update_num_computed_tokens(context_len)
         seq_group_metadata = SequenceGroupMetadata(
8 changes: 1 addition & 7 deletions vllm/inputs/registry.py
@@ -1,5 +1,4 @@
 import functools
-from array import array
 from collections import UserDict
 from dataclasses import dataclass
 from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional,
@@ -22,10 +21,6 @@

 C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig)
 
-# NOTE: This has to match with sequence.py's VLLM_TOKEN_ID_ARRAY_TYPE.
-# We cannot import it here because of circular dependencies.
-VLLM_TOKEN_ID_ARRAY_TYPE = "l"
-
 
 @dataclass(frozen=True)
 class InputContext:
@@ -130,8 +125,7 @@ def _default_dummy_data_factory(
         # Avoid circular import
         from vllm.sequence import SequenceData
 
-        dummy_seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len)
+        dummy_seq_data = SequenceData.from_counts({0: seq_len})
         dummy_multi_modal_data = None
 
         return dummy_seq_data, dummy_multi_modal_data
61 changes: 37 additions & 24 deletions vllm/sequence.py
@@ -5,6 +5,7 @@
 from array import array
 from collections import defaultdict
 from dataclasses import dataclass
+from functools import cached_property, reduce
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional
 from typing import Sequence as GenericSequence
 from typing import Set, Tuple, Union, cast
@@ -169,6 +170,35 @@ class SequenceData(msgspec.Struct,
     # It is used to compute mrope_position_ids.
     _mrope_position_delta: Optional[int] = None
 
+    @staticmethod
+    def from_counts(counts_by_token: Mapping[int, int]) -> "SequenceData":
+        if len(counts_by_token) == 0:
+            return SequenceData.from_seqs([])
+
+        arrs = [
+            array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count
+            for token_id, count in counts_by_token.items()
+        ]
+
+        return SequenceData(reduce(array.__add__, arrs))
+
+    @staticmethod
+    def from_seqs(
+        prompt_token_ids: GenericSequence[int],
+        output_token_ids: Optional[GenericSequence[int]] = None,
+    ) -> "SequenceData":
+        prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                                     prompt_token_ids)
+
+        if output_token_ids is None:
+            return SequenceData(prompt_token_ids_arr)
+
+        output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                                     output_token_ids)
+
+        return SequenceData(prompt_token_ids_arr,
+                            _output_token_ids=output_token_ids_arr)
+
     def __post_init__(self) -> None:
         assert self._prompt_token_ids.typecode == "l"
         assert self._output_token_ids.typecode == "l"
@@ -370,8 +400,6 @@ def __init__(
         self.lora_request = lora_request
         self.prompt_adapter_request = prompt_adapter_request
         self.from_decoder_prompt = from_decoder_prompt
-        self._prompt: Optional[str] = None
-        self._prompt_token_ids: Optional[List[int]] = None
 
         # For decoder-only models, a Sequence is constructed
         # from an LLMInputs instance (the `inputs` arg.)
@@ -400,8 +428,7 @@ def __init__(
f"invalid input {inputs}; did you forget the "
"encoder input prompt fields?")

self.data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, self.prompt_token_ids))
self.data = SequenceData.from_seqs(self.prompt_token_ids)
self.output_logprobs: SampleLogprobs = []
self.output_text = ""

@@ -422,37 +449,23 @@ def __init__(
     def n_blocks(self) -> int:
         return (self.get_len() + self.block_size - 1) // self.block_size
 
-    @property
+    @cached_property
     def prompt(self) -> Optional[str]:
-        if self._prompt is not None:
-            # Reuse precomputed prompt string
-            return self._prompt
-
-        # Select decoder or encoder input prompt str,
-        # as appropriate
+        # Select decoder or encoder input prompt str, as appropriate
         prompt_key: str = ("prompt"
                            if self.from_decoder_prompt else "encoder_prompt")
 
-        # Cache prompt
-        self._prompt = cast(Optional[str], self.inputs.get(prompt_key))
-        return self._prompt
+        return cast(Optional[str], self.inputs.get(prompt_key))
 
-    @property
+    @cached_property
     def prompt_token_ids(self) -> List[int]:
-        if self._prompt_token_ids is not None:
-            # Reuse precomputed prompt token ids
-            return self._prompt_token_ids
-
-        # Select decoder or encoder input prompt
-        # token ids, as appropriate
+        # Select decoder or encoder input prompt token ids, as appropriate
         prompt_token_ids_key: str = ("prompt_token_ids"
                                      if self.from_decoder_prompt else
                                      "encoder_prompt_token_ids")
 
-        # Cache computed prompt token ids
-        self._prompt_token_ids = cast(List[int],
-                                      self.inputs.get(prompt_token_ids_key))
-        return self._prompt_token_ids
+        return cast(List[int], self.inputs.get(prompt_token_ids_key))
 
     @property
     def multi_modal_data(self) -> "MultiModalDataDict":
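A small behavioral sketch of the new from_counts helper (illustrative only; it assumes the pre-existing SequenceData.prompt_token_ids accessor, which is not part of this diff): each {token_id: count} entry expands into a run of repeated ids, and the runs are concatenated in mapping order, so it is equivalent to spelling the run out through from_seqs.

from vllm.sequence import SequenceData

# Three copies of token 2 followed by one copy of token 5.
by_count = SequenceData.from_counts({2: 3, 5: 1})
# The same prompt written out explicitly.
by_seq = SequenceData.from_seqs([2, 2, 2, 5])

# Assumed accessor: prompt_token_ids exposes the prompt token ids.
assert list(by_count.prompt_token_ids) == [2, 2, 2, 5]
assert list(by_count.prompt_token_ids) == list(by_seq.prompt_token_ids)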
