diff --git a/requirements-common.txt b/requirements-common.txt index b0e599a5e5af5..b6bed8a73d8c8 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -21,6 +21,7 @@ outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0 typing_extensions >= 4.10 filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 pyzmq +msgspec librosa # Required for audio processing soundfile # Required for audio processing gguf == 0.9.1 diff --git a/tests/basic_correctness/test_preemption.py b/tests/basic_correctness/test_preemption.py index 7aed0d5e1fa69..7c62de9fa9e37 100644 --- a/tests/basic_correctness/test_preemption.py +++ b/tests/basic_correctness/test_preemption.py @@ -8,6 +8,7 @@ import pytest from prometheus_client import REGISTRY +import vllm.envs as envs from vllm import SamplingParams from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT, ENABLE_ARTIFICIAL_PREEMPT) @@ -24,6 +25,13 @@ "tests/basic_correctness/test_preemption.py`") +@pytest.fixture +def worker_use_ray() -> bool: + # When SPMD worker is used, use ray_use_worker=True + # to test delta input optimization works with preemption. + return envs.VLLM_USE_RAY_SPMD_WORKER + + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [96]) @@ -36,6 +44,7 @@ def test_chunked_prefill_recompute( dtype: str, max_tokens: int, chunked_prefill_token_size: int, + worker_use_ray: bool, ) -> None: """Ensure that chunked prefill works with preemption.""" max_num_seqs = min(chunked_prefill_token_size, 256) @@ -54,6 +63,7 @@ def test_chunked_prefill_recompute( max_num_batched_tokens=max_num_batched_tokens, enable_chunked_prefill=enable_chunked_prefill, max_num_seqs=max_num_seqs, + worker_use_ray=worker_use_ray, ) as vllm_model: vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt @@ -79,6 +89,7 @@ def test_preemption( model: str, dtype: str, max_tokens: int, + worker_use_ray: bool, ) -> None: """By default, recompute preemption is enabled""" @@ -89,6 +100,7 @@ def test_preemption( model, dtype=dtype, disable_log_stats=False, + worker_use_ray=worker_use_ray, ) as vllm_model: vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt @@ -132,6 +144,7 @@ def test_swap( dtype: str, max_tokens: int, beam_width: int, + worker_use_ray: bool, ) -> None: """Use beam search enables swapping.""" example_prompts = example_prompts[:1] @@ -144,6 +157,7 @@ def test_swap( dtype=dtype, swap_space=10, disable_log_stats=False, + worker_use_ray=worker_use_ray, ) as vllm_model: vllm_outputs = vllm_model.generate_beam_search(example_prompts, beam_width, max_tokens) @@ -188,6 +202,7 @@ def test_swap_infeasible( dtype: str, max_tokens: int, beam_width: int, + worker_use_ray: bool, ) -> None: """Verify infeasible swap request will be ignored.""" BLOCK_SIZE = 16 @@ -204,6 +219,7 @@ def test_swap_infeasible( # decode blocks are not enough to finish. 
num_gpu_blocks_override=prefill_blocks + decode_blocks, max_model_len=(prefill_blocks + decode_blocks) * BLOCK_SIZE, + worker_use_ray=worker_use_ray, ) as vllm_model: sampling_params = SamplingParams(n=beam_width, use_beam_search=True, @@ -230,6 +246,7 @@ def test_preemption_infeasible( model: str, dtype: str, max_tokens: int, + worker_use_ray: bool, ) -> None: """Verify infeasible preemption request will be ignored.""" BLOCK_SIZE = 16 @@ -244,6 +261,7 @@ def test_preemption_infeasible( # ignored instead of hanging forever. num_gpu_blocks_override=prefill_blocks + decode_blocks // 2, max_model_len=((prefill_blocks + decode_blocks // 2) * BLOCK_SIZE), + worker_use_ray=worker_use_ray, ) as vllm_model: sampling_params = SamplingParams(max_tokens=max_tokens, ignore_eos=True) diff --git a/tests/core/test_serialization.py b/tests/core/test_serialization.py new file mode 100644 index 0000000000000..d604e5250a3f9 --- /dev/null +++ b/tests/core/test_serialization.py @@ -0,0 +1,33 @@ +import msgspec + +from vllm.executor.msgspec_utils import decode_hook, encode_hook +from vllm.sequence import ExecuteModelRequest + +from ..spec_decode.utils import create_batch + + +def test_msgspec_serialization(): + num_lookahead_slots = 4 + seq_group_metadata_list, _, _ = create_batch(16, num_lookahead_slots) + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=num_lookahead_slots, + running_queue_size=4) + + encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook) + decoder = msgspec.msgpack.Decoder(ExecuteModelRequest, + dec_hook=decode_hook) + req = decoder.decode(encoder.encode(execute_model_req)) + expected = execute_model_req.seq_group_metadata_list + actual = req.seq_group_metadata_list + assert (len(expected) == len(actual)) + expected = expected[0] + actual = actual[0] + + assert expected.block_tables == actual.block_tables + assert expected.is_prompt == actual.is_prompt + assert expected.request_id == actual.request_id + assert (expected.seq_data[0].prompt_token_ids == + actual.seq_data[0].prompt_token_ids) + assert (expected.seq_data[0].output_token_ids == + actual.seq_data[0].output_token_ids) diff --git a/tests/distributed/test_basic_distributed_correctness.py b/tests/distributed/test_basic_distributed_correctness.py index 1de2ebab22db4..e254686f269b1 100644 --- a/tests/distributed/test_basic_distributed_correctness.py +++ b/tests/distributed/test_basic_distributed_correctness.py @@ -22,7 +22,8 @@ @pytest.mark.skipif(cuda_device_count_stateless() < 2, reason="Need at least 2 GPUs to run the test.") @pytest.mark.parametrize( - "model, distributed_executor_backend, attention_backend, test_suite", [ + "model, distributed_executor_backend, attention_backend, " + "test_suite", [ ("facebook/opt-125m", "ray", "", "L4"), ("facebook/opt-125m", "mp", "", "L4"), ("meta-llama/Llama-2-7b-hf", "ray", "", "L4"), diff --git a/tests/distributed/test_chunked_prefill_distributed.py b/tests/distributed/test_chunked_prefill_distributed.py index 10921a3852f81..262845f19822f 100644 --- a/tests/distributed/test_chunked_prefill_distributed.py +++ b/tests/distributed/test_chunked_prefill_distributed.py @@ -6,6 +6,8 @@ ``` """ +import os + import pytest from vllm.utils import cuda_device_count_stateless @@ -30,6 +32,11 @@ def test_models( model: str, distributed_executor_backend: str, ) -> None: + if model == "meta-llama/Llama-2-7b-hf" and distributed_executor_backend == "ray": # noqa + assert distributed_executor_backend == "ray" + # test ray adag + 
os.environ['VLLM_USE_RAY_SPMD_WORKER'] = "1" + os.environ['VLLM_USE_RAY_COMPILED_DAG'] = "1" dtype = "half" max_tokens = 5 diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 74e7486e8012e..820fb554888f0 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -1,5 +1,6 @@ import itertools import random +from array import array from typing import Dict, List, Optional, Tuple from unittest.mock import Mock, patch @@ -10,7 +11,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed -from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata +from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams, + SequenceData, SequenceGroupMetadata) from vllm.utils import Counter, is_pin_memory_available @@ -56,7 +58,9 @@ def _do_sample( SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, - seq_data={0: SequenceData([1, 2, 3])}, + seq_data={ + 0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3])) + }, sampling_params=sampling_params, block_tables={0: [1]}, )) @@ -201,7 +205,8 @@ def create_sampling_params(min_tokens, def create_sequence_data(num_input=3, num_generated=0): seq_data = SequenceData( - random.choices(range(0, VOCAB_SIZE), k=num_input)) + array(VLLM_TOKEN_ID_ARRAY_TYPE, + random.choices(range(0, VOCAB_SIZE), k=num_input))) if num_generated > 0: seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE), k=num_generated) @@ -504,7 +509,9 @@ def test_sampler_mixed(seed: int, device: str): SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, - seq_data={0: SequenceData([1, 2, 3])}, + seq_data={ + 0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3])) + }, sampling_params=sampling_params, block_tables={0: [1]}, )) @@ -600,7 +607,9 @@ def test_sampler_top_k_top_p(seed: int, device: str): SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, - seq_data={0: SequenceData([1, 2, 3])}, + seq_data={ + 0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3])) + }, sampling_params=SamplingParams( temperature=1, top_k=top_k, @@ -650,7 +659,11 @@ def test_sampling_params(sampling_params: List[SamplingParams]): SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, - seq_data={0: SequenceData([1, 2, 3])}, + seq_data={ + 0: + SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, + [1, 2, 3])) + }, sampling_params=sampling_params[i], block_tables={0: [1]}, )) diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index 30eb99f868bfc..60b36a33d9077 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -1,3 +1,4 @@ +from array import array from itertools import count from typing import Callable, Dict, List, Optional from typing import Sequence as GenericSequence @@ -9,7 +10,8 @@ from vllm.engine.arg_utils import EngineArgs from vllm.model_executor.utils import set_random_seed from vllm.sampling_params import SamplingParams -from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, +from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, + CompletionSequenceGroupOutput, Logprob, SamplerOutput, SequenceData, SequenceGroupMetadata, SequenceOutput) from vllm.utils import get_distributed_init_method, get_ip, get_open_port @@ -138,8 +140,9 @@ def create_seq_group_metadata_from_prompts( seq_data={ i: SequenceData( - prompt_token_ids=prompt_token_ids[:], - output_token_ids=cont_token_ids[:], + array(VLLM_TOKEN_ID_ARRAY_TYPE, 
prompt_token_ids[:]), + _output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE, + cont_token_ids[:]), ), }, sampling_params=SamplingParams(temperature=0.0, ), diff --git a/tests/test_logits_processor.py b/tests/test_logits_processor.py index 7d4af963e25c5..1ce49a50688ae 100644 --- a/tests/test_logits_processor.py +++ b/tests/test_logits_processor.py @@ -1,4 +1,5 @@ import random +from array import array from typing import Tuple from unittest.mock import patch @@ -8,7 +9,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed -from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata +from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams, + SequenceData, SequenceGroupMetadata) from vllm.utils import is_pin_memory_available @@ -69,7 +71,9 @@ def pick_ith(token_ids, logits): SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, - seq_data={0: SequenceData([1, 2, 3])}, + seq_data={ + 0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3])) + }, sampling_params=SamplingParams(temperature=0, logits_processors=[pick_ith]), block_tables={0: [1]}, diff --git a/tests/test_sequence.py b/tests/test_sequence.py index 3136402518b9f..1ae349e808e0d 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -1,6 +1,9 @@ +from array import array + import pytest -from vllm.sequence import (CompletionSequenceGroupOutput, SamplerOutput, +from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, + CompletionSequenceGroupOutput, SamplerOutput, SequenceData, SequenceOutput) from .core.utils import create_dummy_prompt @@ -54,7 +57,7 @@ def test_sampler_output_eq(sample_outputs): def test_sequence_data_prefill(): - seq_data = SequenceData(prompt_token_ids=[1, 2, 3, 4]) + seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3, 4])) assert seq_data.get_num_uncomputed_tokens() == 4 assert seq_data.get_num_computed_tokens() == 0 # advance by 2 diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index 8a2e9b81580fc..32bff22f66a8b 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -1,10 +1,12 @@ +from array import array from typing import List import pytest import torch from vllm.engine.arg_utils import EngineArgs -from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata +from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams, + SequenceData, SequenceGroupMetadata) from vllm.utils import is_cpu from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner @@ -125,10 +127,12 @@ def test_prepare_prompt( # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) - seq_data = SequenceData(list(range(seq_len))) + seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, + range(seq_len))) encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 encoder_seq_lens.append(encoder_seq_len) - encoder_seq_data = SequenceData(list(range(encoder_seq_len))) + encoder_seq_data = SequenceData( + array(VLLM_TOKEN_ID_ARRAY_TYPE, range(encoder_seq_len))) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -319,10 +323,12 @@ def test_prepare_decode( # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) - seq_data = 
SequenceData(list(range(seq_len))) + seq_data = SequenceData( + array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len)))) encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 encoder_seq_lens.append(encoder_seq_len) - encoder_seq_data = SequenceData(list(range(encoder_seq_len))) + encoder_seq_data = SequenceData( + array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len)))) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=False, diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 84502043cbd26..a20aa37bcc1e2 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -1,3 +1,4 @@ +from array import array from typing import List import pytest @@ -7,7 +8,8 @@ init_distributed_environment) from vllm.engine.arg_utils import EngineArgs from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata +from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams, + SequenceData, SequenceGroupMetadata) from vllm.utils import get_open_port from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size @@ -46,7 +48,8 @@ def test_prepare_prompt(batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) - seq_data = SequenceData(list(range(seq_len))) + seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, + range(seq_len))) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -163,7 +166,8 @@ def test_prepare_decode_cuda_graph(batch_size): # make sure all tokens fit into one block context_len = i % (model_runner.block_size - 1) + 1 context_lens.append(context_len) - seq_data = SequenceData(list(range(context_len))) + seq_data = SequenceData( + array(VLLM_TOKEN_ID_ARRAY_TYPE, range(context_len))) seq_data.update_num_computed_tokens(context_len) # Append one token ID since prefill is finished. seq_data.append_token_id(1, 0) @@ -324,7 +328,8 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) - seq_data = SequenceData(list(range(seq_len))) + seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, + range(seq_len))) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -340,7 +345,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): for i in range(prefill_batch_size, batch_size): # make sure all tokens fit into one block context_len = i % (model_runner.block_size - 1) + 1 - prompt_toks = list(range(context_len)) + prompt_toks = array(VLLM_TOKEN_ID_ARRAY_TYPE, range(context_len)) seq_data = SequenceData(prompt_toks) seq_data.append_token_id(1, 0) seq_data.update_num_computed_tokens(context_len) diff --git a/vllm/adapter_commons/request.py b/vllm/adapter_commons/request.py index f98adeba1c705..2bb17fdc01109 100644 --- a/vllm/adapter_commons/request.py +++ b/vllm/adapter_commons/request.py @@ -1,8 +1,6 @@ from abc import ABC, abstractmethod -from dataclasses import dataclass -@dataclass class AdapterRequest(ABC): """ Base class for adapter requests. 
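
To make the serialization changes above easier to follow, here is a minimal, self-contained sketch of the round-trip that tests/core/test_serialization.py and vllm/executor/msgspec_utils.py rely on: token IDs held in array('l', ...) are shipped as raw bytes via msgspec's enc_hook and rebuilt by dec_hook on the receiving side. DummyRequest, TOKEN_ID_ARRAY_TYPE, and the locally defined hooks are illustrative stand-ins that mirror the PR's helpers, not vLLM APIs.

from array import array
from typing import Any, Type

import msgspec

TOKEN_ID_ARRAY_TYPE = "l"  # mirrors VLLM_TOKEN_ID_ARRAY_TYPE in vllm/sequence.py


class DummyRequest(msgspec.Struct, array_like=True, omit_defaults=True):
    """Illustrative stand-in for an ExecuteModelRequest-style payload."""
    request_id: str
    prompt_token_ids: array


def enc_hook(obj: Any) -> Any:
    # array is not natively supported by msgpack, so ship its raw bytes.
    if isinstance(obj, array):
        return obj.tobytes()
    raise NotImplementedError(f"Unsupported type: {type(obj)}")


def dec_hook(type: Type, obj: Any) -> Any:
    # Rebuild the array from bytes on the receiving (worker) side.
    if type is array:
        arr = array(TOKEN_ID_ARRAY_TYPE)
        arr.frombytes(obj)
        return arr
    raise NotImplementedError(f"Unsupported type: {type}")


encoder = msgspec.msgpack.Encoder(enc_hook=enc_hook)
decoder = msgspec.msgpack.Decoder(DummyRequest, dec_hook=dec_hook)

req = DummyRequest("req-0", array(TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
assert decoder.decode(encoder.encode(req)).prompt_token_ids == req.prompt_token_ids

The same Encoder/Decoder pairing is what RayGPUExecutor and RayWorkerWrapper set up in the hunks above, with ExecuteModelRequest in place of DummyRequest.
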
diff --git a/vllm/config.py b/vllm/config.py index e03adb5f5c963..781441427b3ad 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -741,8 +741,8 @@ def __init__( self.tokenizer_pool_config = tokenizer_pool_config self.ray_workers_use_nsight = ray_workers_use_nsight self.placement_group = placement_group - self.world_size = pipeline_parallel_size * self.tensor_parallel_size + if worker_use_ray: if self.distributed_executor_backend is None: self.distributed_executor_backend = "ray" @@ -838,6 +838,11 @@ class SchedulerConfig: swapping. However, when the sequence group has multiple sequences (e.g., beam search), recomputation is not currently supported. In such a case, we use swapping instead. + send_delta_data: Private API. If used, scheduler sends delta data to + workers instead of an entire data. It should be enabled only + when SPMD worker architecture is enabled. I.e., + VLLM_USE_RAY_SPMD_WORKER=1 + """ def __init__(self, @@ -850,7 +855,8 @@ def __init__(self, enable_chunked_prefill: bool = False, embedding_mode: Optional[bool] = False, preemption_mode: Optional[str] = None, - num_scheduler_steps: int = 1) -> None: + num_scheduler_steps: int = 1, + send_delta_data: bool = False) -> None: if max_num_batched_tokens is not None: self.max_num_batched_tokens = max_num_batched_tokens else: @@ -880,6 +886,7 @@ def __init__(self, self.embedding_mode = embedding_mode self.preemption_mode = preemption_mode self.num_scheduler_steps = num_scheduler_steps + self.send_delta_data = send_delta_data self._verify_args() def _verify_args(self) -> None: diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 287de60149670..802359d2283f7 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -12,7 +12,8 @@ from vllm.lora.request import LoRARequest from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import (Sequence, SequenceData, SequenceGroup, - SequenceGroupMetadata, SequenceStatus) + SequenceGroupMetadata, SequenceGroupMetadataDelta, + SequenceStatus) from vllm.utils import PyObjectCache logger = init_logger(__name__) @@ -363,8 +364,6 @@ def __init__( self.num_cumulative_preemption: int = 0 # Used to cache python objects - self._seq_group_metadata_cache: PyObjectCache = PyObjectCache( - seq_group_metadata_builder) self._scheduler_running_outputs_cache: PyObjectCache = PyObjectCache( scheduler_running_outputs_builder) self._scheduled_seq_group_cache: PyObjectCache = PyObjectCache( @@ -1048,15 +1047,10 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: token_chunk_size = scheduled_seq_group.token_chunk_size seq_group.maybe_set_first_scheduled_time(now) - seq_group_metadata = self._seq_group_metadata_cache.get_object() - seq_group_metadata.seq_data.clear() - seq_group_metadata.block_tables.clear() - # seq_id -> SequenceData - seq_data: Dict[int, SequenceData] = seq_group_metadata.seq_data + seq_data: Dict[int, SequenceData] = {} # seq_id -> physical block numbers - block_tables: Dict[int, - List[int]] = seq_group_metadata.block_tables + block_tables: Dict[int, List[int]] = {} if seq_group.is_encoder_decoder(): # Encoder associated with SequenceGroup @@ -1081,45 +1075,65 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: seq_group.get_seqs(status=SequenceStatus.RUNNING))) do_sample = True - if seq_group.is_prefill(): + is_prompt = seq_group.is_prefill() + # We should send the metadata to workers when the first prefill + # is sent. Subsequent requests could be chunked prefill or decode. 
+ is_first_prefill = False + if is_prompt: seqs = seq_group.get_seqs() # Prefill has only 1 sequence. assert len(seqs) == 1 + num_computed_tokens = seqs[0].data.get_num_computed_tokens() + is_first_prefill = num_computed_tokens == 0 # In the next iteration, all prompt tokens are not computed. # It means the prefill is chunked, and we don't need sampling. # NOTE: We use get_len instead of get_prompt_len because when # a sequence is preempted, prefill includes previous generated # output tokens. - if (token_chunk_size + seqs[0].data.get_num_computed_tokens() < + if (token_chunk_size + num_computed_tokens < seqs[0].data.get_len()): do_sample = False # It assumes the scheduled_seq_groups is ordered by # prefill < decoding. - is_prompt = seq_group.is_prefill() - - seq_group_metadata.__init__( - request_id=seq_group.request_id, - is_prompt=is_prompt, - seq_data=seq_data, - sampling_params=seq_group.sampling_params, - block_tables=block_tables, - do_sample=do_sample, - pooling_params=seq_group.pooling_params, - token_chunk_size=token_chunk_size, - lora_request=seq_group.lora_request, - computed_block_nums=common_computed_block_nums, - encoder_seq_data=encoder_seq_data, - cross_block_table=cross_block_table, - state=seq_group.state, - # `multi_modal_data` will only be present for the 1st comm - # between engine and worker. - # the subsequent comms can still use delta, but - # `multi_modal_data` will be None. - multi_modal_data=seq_group.multi_modal_data - if scheduler_outputs.num_prefill_groups > 0 else None, - prompt_adapter_request=seq_group.prompt_adapter_request, - ) + if is_first_prefill or not self.scheduler_config.send_delta_data: + seq_group_metadata = SequenceGroupMetadata( + request_id=seq_group.request_id, + is_prompt=is_prompt, + seq_data=seq_data, + sampling_params=seq_group.sampling_params, + block_tables=block_tables, + do_sample=do_sample, + pooling_params=seq_group.pooling_params, + token_chunk_size=token_chunk_size, + lora_request=seq_group.lora_request, + computed_block_nums=common_computed_block_nums, + encoder_seq_data=encoder_seq_data, + cross_block_table=cross_block_table, + state=seq_group.state, + # `multi_modal_data` will only be present for the 1st comm + # between engine and worker. + # the subsequent comms can still use delta, but + # `multi_modal_data` will be None. + multi_modal_data=seq_group.multi_modal_data + if scheduler_outputs.num_prefill_groups > 0 else None, + prompt_adapter_request=seq_group.prompt_adapter_request, + ) + else: + # When SPMD mode is enabled, we only send delta data except for + # the first request to reduce serialization cost. + seq_data_delta = {} + for id, data in seq_data.items(): + seq_data_delta[id] = data.get_delta_and_reset() + seq_group_metadata = SequenceGroupMetadataDelta( + seq_data_delta, + seq_group.request_id, + block_tables, + is_prompt, + do_sample=do_sample, + token_chunk_size=token_chunk_size, + computed_block_nums=common_computed_block_nums, + ) seq_group_metadata_list.append(seq_group_metadata) # Now that the batch has been created, we can assume all blocks in the @@ -1130,8 +1144,6 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: self.block_manager.mark_blocks_as_computed( scheduled_seq_group.seq_group) - self._seq_group_metadata_cache.reset() - scheduler_time = time.perf_counter() - scheduler_start_time # Add this to scheduler time to all the sequences that are currently # running. 
This will help estimate if the scheduler is a significant diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 6c7259129a109..ad822e51d4512 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -5,6 +5,7 @@ from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Type, Union) +import vllm.envs as envs from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, EngineConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ObservabilityConfig, ParallelConfig, @@ -907,6 +908,8 @@ def create_engine_config(self, ) -> EngineConfig: embedding_mode=model_config.embedding_mode, preemption_mode=self.preemption_mode, num_scheduler_steps=self.num_scheduler_steps, + send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER + and parallel_config.use_ray), ) lora_config = LoRAConfig( max_lora_rank=self.max_lora_rank, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 979555eb6a05d..a62f0b599652f 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -228,7 +228,6 @@ def __init__( cache_config.enable_prefix_caching, ) # TODO(woosuk): Print more configs in debug mode. - from vllm.plugins import load_general_plugins load_general_plugins() diff --git a/vllm/executor/msgspec_utils.py b/vllm/executor/msgspec_utils.py new file mode 100644 index 0000000000000..c467115f124ca --- /dev/null +++ b/vllm/executor/msgspec_utils.py @@ -0,0 +1,27 @@ +from array import array +from typing import Any, Type + +from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE + + +def encode_hook(obj: Any) -> Any: + """Custom msgspec enc hook that supports array types. + + See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder + """ + if isinstance(obj, array): + assert obj.typecode == VLLM_TOKEN_ID_ARRAY_TYPE, ( + f"vLLM array type should use '{VLLM_TOKEN_ID_ARRAY_TYPE}' type. " + f"Given array has a type code of {obj.typecode}.") + return obj.tobytes() + + +def decode_hook(type: Type, obj: Any) -> Any: + """Custom msgspec dec hook that supports array types. + + See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder + """ + if type is array: + deserialized = array(VLLM_TOKEN_ID_ARRAY_TYPE) + deserialized.frombytes(obj) + return deserialized diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index fa3646012dd6e..3a08ab4dbfd44 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -4,9 +4,12 @@ from itertools import islice, repeat from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +import msgspec + import vllm.envs as envs from vllm.executor.distributed_gpu_executor import ( # yapf: disable DistributedGPUExecutor, DistributedGPUExecutorAsync) +from vllm.executor.msgspec_utils import encode_hook from vllm.executor.ray_utils import RayWorkerWrapper, ray from vllm.logger import init_logger from vllm.sequence import ExecuteModelRequest, SamplerOutput @@ -60,6 +63,10 @@ def _init_executor(self) -> None: # Create the parallel GPU workers. self._init_workers_ray(placement_group) + self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook) + self.output_decoder = msgspec.msgpack.Decoder( + Optional[List[SamplerOutput]]) + def shutdown(self) -> None: if hasattr(self, "forward_dag") and self.forward_dag is not None: self.forward_dag.teardown() @@ -123,6 +130,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", ray_remote_kwargs) logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker) + # Create the workers. 
driver_ip = get_ip() worker_wrapper_kwargs = self._get_worker_wrapper_args() @@ -304,8 +312,10 @@ def execute_model( if self.forward_dag is None: self.forward_dag = self._compiled_ray_dag(enable_asyncio=False) - outputs = ray.get(self.forward_dag.execute(execute_model_req)) - return outputs[0] + serialized_data = self.input_encoder.encode(execute_model_req) + outputs = ray.get(self.forward_dag.execute(serialized_data)) + output = self.output_decoder.decode(outputs[0]) + return output def _run_workers( self, @@ -475,9 +485,10 @@ async def execute_model_async( if self.forward_dag is None: self.forward_dag = self._compiled_ray_dag(enable_asyncio=True) - dag_future = await self.forward_dag.execute_async(execute_model_req) + serialized_data = self.input_encoder.encode(execute_model_req) + dag_future = await self.forward_dag.execute_async(serialized_data) outputs = await dag_future - return outputs[0] + return self.output_decoder.decode(outputs[0]) async def _driver_execute_model_async( self, diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index ab283467d4783..ffc94d07ed399 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -1,6 +1,9 @@ from typing import List, Optional, Tuple, Union +import msgspec + from vllm.config import ParallelConfig +from vllm.executor.msgspec_utils import decode_hook, encode_hook from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.sequence import ExecuteModelRequest, IntermediateTensors @@ -24,6 +27,10 @@ def __init__(self, *args, **kwargs) -> None: # that thread. self.compiled_dag_cuda_device_set = False + self.input_decoder = msgspec.msgpack.Decoder(ExecuteModelRequest, + dec_hook=decode_hook) + self.output_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook) + def get_node_ip(self) -> str: return get_ip() @@ -33,16 +40,26 @@ def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]: return node_id, gpu_ids def execute_model_spmd( - self, req_or_tuple: Union[ExecuteModelRequest, - Tuple[ExecuteModelRequest, - IntermediateTensors]]): + self, req_or_tuple: Union[bytes, + Tuple[bytes, + Optional[IntermediateTensors]]] + ) -> bytes: """Execute model in SPMD fashion: used only when SPMD worker and compiled DAG are both enabled. Args: - req_or_tuple: The request to execute the model, or a tuple - containing the request and intermediate tensors. + req_or_tuple: A request or a tuple containing the + request and intermediate tensors. Intermediate tensors are + None unless if it is provided because it is > 0 pipeline + stage. The request is serialized by msgspec. """ + if isinstance(req_or_tuple, bytes): + serialized_req, intermediate_tensors = req_or_tuple, None + else: + serialized_req, intermediate_tensors = req_or_tuple + + execute_model_req = self.input_decoder.decode(serialized_req) + # TODO(swang): This is needed right now because Ray aDAG executes # on a background thread, so we need to reset torch's current # device. @@ -51,16 +68,14 @@ def execute_model_spmd( torch.cuda.set_device(self.worker.device) self.compiled_dag_cuda_device_set = True - if isinstance(req_or_tuple, tuple): - execute_model_req, intermediate_tensors = req_or_tuple - else: - execute_model_req = req_or_tuple - intermediate_tensors = None - output = self.worker._execute_model_spmd(execute_model_req, intermediate_tensors) + # Pipeline model request and output to the next pipeline stage. 
if isinstance(output, IntermediateTensors): - return execute_model_req, output + output = serialized_req, output + else: + output = self.output_encoder.encode(output) + return output ray_import_err = None diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 2ca8b10f71593..7c17895bf1d1a 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -1,4 +1,5 @@ import functools +from array import array from collections import UserDict from dataclasses import dataclass from typing import (TYPE_CHECKING, Callable, Dict, Mapping, Optional, Protocol, @@ -21,6 +22,10 @@ C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig) +# NOTE: This has to match with sequence.py's VLLM_TOKEN_ID_ARRAY_TYPE. +# We cannot import it here because of circular dependencies. +VLLM_TOKEN_ID_ARRAY_TYPE = "l" + @dataclass(frozen=True) class InputContext: @@ -132,7 +137,8 @@ def _default_dummy_data_factory( # Avoid circular import from vllm.sequence import SequenceData - dummy_seq_data = SequenceData([0] * seq_len) + dummy_seq_data = SequenceData( + array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len) dummy_multi_modal_data = None return dummy_seq_data, dummy_multi_modal_data diff --git a/vllm/lora/request.py b/vllm/lora/request.py index 5d791424fbe6e..d770da4f2407d 100644 --- a/vllm/lora/request.py +++ b/vllm/lora/request.py @@ -1,12 +1,15 @@ import warnings -from dataclasses import dataclass, field from typing import Optional +import msgspec + from vllm.adapter_commons.request import AdapterRequest -@dataclass -class LoRARequest(AdapterRequest): +class LoRARequest( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + array_like=True): # type: ignore[call-arg] """ Request for a LoRA adapter. @@ -18,16 +21,17 @@ class LoRARequest(AdapterRequest): lora_int_id must be globally unique for a given adapter. This is currently not enforced in vLLM. """ + __metaclass__ = AdapterRequest lora_name: str lora_int_id: int lora_path: str = "" - lora_local_path: Optional[str] = field(default=None, repr=False) + lora_local_path: Optional[str] = msgspec.field(default=None) long_lora_max_len: Optional[int] = None __hash__ = AdapterRequest.__hash__ def __post_init__(self): - if 'lora_local_path' in self.__dict__: + if 'lora_local_path' in self.__struct_fields__: warnings.warn( "The 'lora_local_path' attribute is deprecated " "and will be removed in a future version. 
" diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py index a6fd5f58b3cb6..69e777152e3d4 100644 --- a/vllm/model_executor/models/blip.py +++ b/vllm/model_executor/models/blip.py @@ -1,5 +1,6 @@ """Minimal implementation of BlipVisionModel intended to be only used within a vision language model.""" +from array import array from typing import Optional, Union import torch @@ -16,7 +17,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal.image import (cached_get_tokenizer, repeat_and_pad_image_tokens) -from vllm.sequence import SequenceData +from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int: @@ -53,8 +54,10 @@ def dummy_seq_data_for_blip( else: image_feature_size = image_feature_size_override - token_ids = [image_token_id] * image_feature_size - token_ids += [0] * (seq_len - image_feature_size) + token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, + [image_token_id]) * image_feature_size + token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, + [0]) * (seq_len - image_feature_size) return SequenceData(token_ids) diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 386dfeb5bb1e5..8cfd3c2672568 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -1,3 +1,4 @@ +from array import array from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union) @@ -17,7 +18,8 @@ from vllm.model_executor.models.opt import OPTModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData +from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, + SamplerOutput, SequenceData) from .blip import (BlipVisionModel, dummy_image_for_blip, get_max_blip_image_tokens) @@ -427,8 +429,10 @@ def dummy_seq_data_for_blip2( else: image_feature_size = image_feature_size_override - token_ids = [image_token_id] * image_feature_size * num_images - token_ids += [0] * (seq_len - image_feature_size * num_images) + token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, + [image_token_id]) * image_feature_size * num_images + token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, + [0]) * (seq_len - image_feature_size * num_images) return SequenceData(token_ids) diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 6776b93d126b0..788d22db9d5a8 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -1,3 +1,4 @@ +from array import array from functools import cached_property from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict) @@ -31,7 +32,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import (cached_get_tokenizer, repeat_and_pad_image_tokens) -from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData +from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, + SamplerOutput, SequenceData) from vllm.utils import print_warning_once from .interfaces import SupportsMultiModal @@ -70,8 +72,10 @@ def dummy_seq_data_for_chameleon( else: image_feature_size = image_feature_size_override - token_ids = [image_token_id] * image_feature_size * num_images - token_ids += [0] * (seq_len - image_feature_size * num_images) + token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, + 
[image_token_id]) * image_feature_size * num_images + token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, + [0]) * (seq_len - image_feature_size * num_images) return SequenceData(token_ids) diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index fcd360ce8fd72..24eeefdfccf00 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -1,5 +1,6 @@ """Minimal implementation of CLIPVisionModel intended to be only used within a vision language model.""" +from array import array from typing import Iterable, Optional, Tuple import torch @@ -17,7 +18,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal.image import (cached_get_tokenizer, repeat_and_pad_image_tokens) -from vllm.sequence import SequenceData +from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int: @@ -53,8 +54,10 @@ def dummy_seq_data_for_clip( else: image_feature_size = image_feature_size_override - token_ids = [image_token_id] * image_feature_size * num_images - token_ids += [0] * (seq_len - image_feature_size * num_images) + token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, + [image_token_id]) * image_feature_size * num_images + token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, + [0]) * (seq_len - image_feature_size * num_images) return SequenceData(token_ids) diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index e8184e466c5bf..2ef23819b69a2 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -16,6 +16,7 @@ # limitations under the License. """ PyTorch Fuyu model.""" import math +from array import array from typing import Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict import torch @@ -37,7 +38,8 @@ from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.image import (cached_get_image_processor, cached_get_tokenizer) -from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData +from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, + SamplerOutput, SequenceData) from .interfaces import SupportsMultiModal from .utils import merge_multimodal_embeddings @@ -97,9 +99,12 @@ def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int, num_images: int): ncol, nrow = get_max_fuyu_image_feature_size() image_feature_size = get_max_fuyu_image_tokens(ctx) - image_token_ids = ([_IMAGE_TOKEN_ID] * ncol + [_NEWLINE_TOKEN_ID]) * nrow - token_ids = image_token_ids * num_images - token_ids += [0] * (seq_len - image_feature_size * num_images) + image_token_ids = ( + array(VLLM_TOKEN_ID_ARRAY_TYPE, [_IMAGE_TOKEN_ID]) * ncol + + array(VLLM_TOKEN_ID_ARRAY_TYPE, [_NEWLINE_TOKEN_ID])) * nrow + token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, image_token_ids) * num_images + token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, + [0]) * (seq_len - image_feature_size * num_images) return SequenceData(token_ids) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index ef2323398abd0..729bd27c334d5 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -23,6 +23,7 @@ """Inference-only MiniCPM-V model compatible with HuggingFace weights.""" import math import re +from array import array from functools import partial from typing import (Any, Callable, Iterable, List, Mapping, Optional, Tuple, TypedDict, Union) @@ -55,7 +56,8 @@ from vllm.multimodal import 
MULTIMODAL_REGISTRY from vllm.multimodal.image import (cached_get_image_processor, cached_get_tokenizer) -from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData +from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, + SamplerOutput, SequenceData) from .idefics2_vision_model import Idefics2VisionTransformer @@ -408,7 +410,7 @@ def get_max_minicpmv_image_tokens(ctx: InputContext): def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int): - token_ids = [0] * seq_len + token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len return SequenceData(token_ids) diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 4df8c0b54201c..426af7fee9544 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -2,6 +2,7 @@ within a vision language model.""" import math +from array import array from typing import Iterable, Optional, Tuple import torch @@ -25,7 +26,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal.image import (cached_get_tokenizer, repeat_and_pad_image_tokens) -from vllm.sequence import SequenceData +from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int: @@ -62,8 +63,10 @@ def dummy_seq_data_for_siglip( else: image_feature_size = image_feature_size_override - token_ids = [image_token_id] * image_feature_size * num_images - token_ids += [0] * (seq_len - image_feature_size * num_images) + token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, + [image_token_id]) * image_feature_size + token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, + [0]) * (seq_len - image_feature_size) return SequenceData(token_ids) diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 94b4b14416821..a085779bc61a7 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -6,7 +6,8 @@ import torch from vllm.sampling_params import SamplingParams, SamplingType -from vllm.sequence import SequenceData, SequenceGroupMetadata +from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData, + SequenceGroupMetadata) from vllm.triton_utils.sample import get_num_triton_sampler_splits from vllm.utils import (PyObjectCache, async_tensor_h2d, is_pin_memory_available, make_tensor_with_pad, @@ -505,9 +506,11 @@ def from_sampling_metadata( and sampling_params.prompt_logprobs is not None): prefill_len = len(seq_group.prompt_logprob_indices) prompt_tokens.extend( - array('l') for _ in range(prefill_len)) + array(VLLM_TOKEN_ID_ARRAY_TYPE) + for _ in range(prefill_len)) output_tokens.extend( - array('l') for _ in range(prefill_len)) + array(VLLM_TOKEN_ID_ARRAY_TYPE) + for _ in range(prefill_len)) if seq_group.do_sample: for seq_id in seq_ids: seq_data = seq_group.seq_data[seq_id] diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py index 3b95d73ddc2c5..7461fb51989c6 100644 --- a/vllm/pooling_params.py +++ b/vllm/pooling_params.py @@ -1,15 +1,18 @@ from typing import Any, Optional +import msgspec -class PoolingParams: + +class PoolingParams( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + array_like=True): # type: ignore[call-arg] """Pooling parameters for pooling. Attributes: additional_data: Any additional data needed for pooling. 
""" - - def __init__(self, additional_data: Optional[Any] = None): - self.additional_data = additional_data + additional_data: Optional[Any] = None def clone(self) -> "PoolingParams": """Returns a deep copy of the PoolingParams instance.""" diff --git a/vllm/prompt_adapter/request.py b/vllm/prompt_adapter/request.py index c0c98cf72bbae..775dd11db0719 100644 --- a/vllm/prompt_adapter/request.py +++ b/vllm/prompt_adapter/request.py @@ -1,13 +1,17 @@ -from dataclasses import dataclass +import msgspec from vllm.adapter_commons.request import AdapterRequest -@dataclass -class PromptAdapterRequest(AdapterRequest): +class PromptAdapterRequest( + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True, # type: ignore[call-arg] + frozen=True): # type: ignore[call-arg] """ Request for a Prompt adapter. """ + __metaclass__ = AdapterRequest prompt_adapter_name: str prompt_adapter_id: int diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 04250c682cd23..7197b51398538 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -2,10 +2,10 @@ import copy from enum import IntEnum from functools import cached_property -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Set, Union +import msgspec import torch -from pydantic import Field from typing_extensions import Annotated from vllm.logger import init_logger @@ -33,7 +33,11 @@ class SamplingType(IntEnum): to sample from.""" -class SamplingParams: +class SamplingParams( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + # required for @cached_property. + dict=True): # type: ignore[call-arg] """Sampling parameters for text generation. Overall, we follow the sampling parameters from the OpenAI text completion @@ -112,87 +116,73 @@ class SamplingParams: (i.e., no truncation). 
""" - def __init__( - self, - n: int = 1, - best_of: Optional[int] = None, - presence_penalty: float = 0.0, - frequency_penalty: float = 0.0, - repetition_penalty: float = 1.0, - temperature: float = 1.0, - top_p: float = 1.0, - top_k: int = -1, - min_p: float = 0.0, - seed: Optional[int] = None, - use_beam_search: bool = False, - length_penalty: float = 1.0, - early_stopping: Union[bool, str] = False, - stop: Optional[Union[str, List[str]]] = None, - stop_token_ids: Optional[List[int]] = None, - include_stop_str_in_output: bool = False, - ignore_eos: bool = False, - max_tokens: Optional[int] = 16, - min_tokens: int = 0, - logprobs: Optional[int] = None, - prompt_logprobs: Optional[int] = None, - detokenize: bool = True, - skip_special_tokens: bool = True, - spaces_between_special_tokens: bool = True, - logits_processors: Optional[List[LogitsProcessor]] = None, - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, - ) -> None: - self.n = n - self.best_of = best_of if best_of is not None else n - self.presence_penalty = presence_penalty - self.frequency_penalty = frequency_penalty - self.repetition_penalty = repetition_penalty - if 0 < temperature < _MAX_TEMP: + n: int = 1 + best_of: Optional[int] = None + presence_penalty: float = 0.0 + frequency_penalty: float = 0.0 + repetition_penalty: float = 1.0 + temperature: float = 1.0 + top_p: float = 1.0 + top_k: int = -1 + min_p: float = 0.0 + seed: Optional[int] = None + use_beam_search: bool = False + length_penalty: float = 1.0 + early_stopping: Union[bool, str] = False + stop: Optional[Union[str, List[str]]] = None + stop_token_ids: Optional[List[int]] = None + ignore_eos: bool = False + max_tokens: Optional[int] = 16 + min_tokens: int = 0 + logprobs: Optional[int] = None + prompt_logprobs: Optional[int] = None + # NOTE: This parameter is only exposed at the engine level for now. + # It is not exposed in the OpenAI API server, as the OpenAI API does + # not support returning only a list of token IDs. + detokenize: bool = True + skip_special_tokens: bool = True + spaces_between_special_tokens: bool = True + # Optional[List[LogitsProcessor]] type. We use Any here because + # Optional[List[LogitsProcessor]] type is not supported by msgspec. + logits_processors: Optional[Any] = None + include_stop_str_in_output: bool = False + truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None + + # The below fields are not supposed to be used as an input. + # They are set in post_init. + output_text_buffer_length: int = 0 + _all_stop_token_ids: Set[int] = msgspec.field(default_factory=set) + + def __post_init__(self) -> None: + self.best_of = self.best_of or self.n + if 0 < self.temperature < _MAX_TEMP: logger.warning( "temperature %s is less than %s, which may cause numerical " "errors nan or inf in tensors. 
We have maxed it out to %s.", - temperature, _MAX_TEMP, _MAX_TEMP) - temperature = max(temperature, _MAX_TEMP) - self.temperature = temperature - self.top_p = top_p - self.top_k = top_k - self.min_p = min_p - if seed == -1: + self.temperature, _MAX_TEMP, _MAX_TEMP) + self.temperature = max(self.temperature, _MAX_TEMP) + if self.seed == -1: self.seed = None else: - self.seed = seed - self.use_beam_search = use_beam_search - self.length_penalty = length_penalty - self.early_stopping = early_stopping - if stop is None: + self.seed = self.seed + if self.stop is None: self.stop = [] - elif isinstance(stop, str): - self.stop = [stop] + elif isinstance(self.stop, str): + self.stop = [self.stop] else: - self.stop = list(stop) - if stop_token_ids is None: + self.stop = list(self.stop) + if self.stop_token_ids is None: self.stop_token_ids = [] else: - self.stop_token_ids = list(stop_token_ids) - self.ignore_eos = ignore_eos - self.max_tokens = max_tokens - self.min_tokens = min_tokens - self.logprobs = 1 if logprobs is True else logprobs - self.prompt_logprobs = 1 if prompt_logprobs is True else prompt_logprobs - # NOTE: This parameter is only exposed at the engine level for now. - # It is not exposed in the OpenAI API server, as the OpenAI API does - # not support returning only a list of token IDs. - self.detokenize = detokenize - self.skip_special_tokens = skip_special_tokens - self.spaces_between_special_tokens = spaces_between_special_tokens - self.logits_processors = logits_processors - self.include_stop_str_in_output = include_stop_str_in_output - self.truncate_prompt_tokens = truncate_prompt_tokens + self.stop_token_ids = list(self.stop_token_ids) + self.logprobs = 1 if self.logprobs is True else self.logprobs + self.prompt_logprobs = (1 if self.prompt_logprobs is True else + self.prompt_logprobs) + # Number of characters to hold back for stop string evaluation # until sequence is finished. - if self.stop and not include_stop_str_in_output: + if self.stop and not self.include_stop_str_in_output: self.output_text_buffer_length = max(len(s) for s in self.stop) - 1 - else: - self.output_text_buffer_length = 0 self._verify_args() if self.use_beam_search: @@ -206,11 +196,12 @@ def __init__( self.min_p = 0.0 self._verify_greedy_sampling() # eos_token_id is added to this by the engine - self.all_stop_token_ids = set(self.stop_token_ids) + self._all_stop_token_ids = set(self.stop_token_ids) def _verify_args(self) -> None: if self.n < 1: raise ValueError(f"n must be at least 1, got {self.n}.") + assert isinstance(self.best_of, int) if self.best_of < self.n: raise ValueError(f"best_of must be greater than or equal to n, " f"got n={self.n} and best_of={self.best_of}.") @@ -257,6 +248,7 @@ def _verify_args(self) -> None: and self.truncate_prompt_tokens < 1): raise ValueError(f"truncate_prompt_tokens must be >= 1, " f"got {self.truncate_prompt_tokens}") + assert isinstance(self.stop, list) if any(not stop_str for stop_str in self.stop): raise ValueError("stop cannot contain an empty string.") if self.stop and not self.detokenize: @@ -290,6 +282,7 @@ def _verify_non_beam_search(self) -> None: "default value of 1.0 when not using beam search.") def _verify_greedy_sampling(self) -> None: + assert isinstance(self.best_of, int) if self.best_of > 1: raise ValueError("best_of must be 1 when using greedy sampling." f"Got {self.best_of}.") @@ -303,7 +296,7 @@ def update_from_generation_config( if model_eos_token_id is not None: # Add the eos token id into the sampling_params to support # min_tokens processing. 
- self.all_stop_token_ids.add(model_eos_token_id) + self._all_stop_token_ids.add(model_eos_token_id) # Update eos_token_id for generation if (eos_ids := generation_config.get("eos_token_id")) is not None: @@ -315,7 +308,7 @@ def update_from_generation_config( # purposes. eos_ids.discard(model_eos_token_id) if eos_ids: - self.all_stop_token_ids.update(eos_ids) + self._all_stop_token_ids.update(eos_ids) if not self.ignore_eos: eos_ids.update(self.stop_token_ids) self.stop_token_ids = list(eos_ids) @@ -330,6 +323,10 @@ def sampling_type(self) -> SamplingType: return SamplingType.RANDOM_SEED return SamplingType.RANDOM + @property + def all_stop_token_ids(self) -> Set[int]: + return self._all_stop_token_ids + def clone(self) -> "SamplingParams": """Deep copy excluding LogitsProcessor objects. diff --git a/vllm/sequence.py b/vllm/sequence.py index b83e345235cdd..b15955cde76cf 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -4,10 +4,11 @@ from abc import ABC, abstractmethod from array import array from collections import defaultdict -from dataclasses import dataclass, field -from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple, - Union, cast) +from dataclasses import dataclass +from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set, + Tuple, Union, cast) +import msgspec import numpy import torch @@ -16,13 +17,18 @@ from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams +from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics if TYPE_CHECKING: from vllm.inputs import LLMInputs - from vllm.multimodal import MultiModalDataDict - from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics + from vllm.multimodal.base import MultiModalDataDict +VLLM_TOKEN_ID_ARRAY_TYPE = "l" + +# We use dataclass for now because it is used for +# openai server output, and msgspec is not serializable. +# TODO(sang): Fix it. @dataclass class Logprob: """Infos for supporting OpenAI compatible logprobs and token ranks. @@ -112,7 +118,23 @@ class RequestMetrics: model_execute_time: Optional[float] = None -class SequenceData: +class SequenceDataDelta( + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True): # type: ignore[call-arg] + """Delta SequenceData to send to workers per step.""" + # A new token to be appended to existing SequenceData. + new_output_token_ids: List[int] + # Overwriting existing `cumulative_logprob` + new_cumulative_logprob: float + # Overwriting existing `num_computed_tokens`. + new_num_computed_tokens: int + # Overwriting existing `stage`. + new_stage: SequenceStage + + +class SequenceData(msgspec.Struct, + omit_defaults=True): # type: ignore[call-arg] """Data associated with a sequence. Args: @@ -125,40 +147,57 @@ class SequenceData: output_token_ids: The token IDs of the output. cumulative_logprob: The cumulative log probability of the output. """ - - def __init__( - self, - prompt_token_ids: List[int], - output_token_ids: Optional[List[int]] = None, - ) -> None: - self._prompt_token_ids = array('l', prompt_token_ids) - self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(prompt_token_ids) - self._output_token_ids = array( - 'l', output_token_ids if output_token_ids is not None else []) - - self.cumulative_logprob = 0.0 - # The number of tokens that are computed (that run against the model). 
- self._num_computed_tokens = 0 - self._stage: SequenceStage = SequenceStage.PREFILL - + # NOTE: we cannot use Union[List, array] because msgspec cannot support + # union of 2 list types. + _prompt_token_ids: array + _output_token_ids: array = msgspec.field( + default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, [])) + + ### The below fields should not be passed as an argument ### + _cumulative_logprob: float = 0.0 + _prompt_token_ids_tuple: Tuple[int, + ...] = msgspec.field(default_factory=tuple) + # The number of tokens that are computed (that run against the model). + _num_computed_tokens: int = 0 + _stage: SequenceStage = SequenceStage.PREFILL + _cached_all_token_ids: List[int] = msgspec.field(default_factory=list) + + # It is used to get delta input. It is reset when `get_delta_and_reset` + # is called. + _new_appended_tokens: List[int] = msgspec.field(default_factory=list) + + def __post_init__(self) -> None: + assert self._prompt_token_ids.typecode == "l" + assert self._output_token_ids.typecode == "l" + self._prompt_token_ids_tuple: Tuple[int, ...] = tuple( + self._prompt_token_ids) self._update_cached_all_tokens() def _update_cached_all_tokens(self): + assert isinstance(self._prompt_token_ids, array) + assert isinstance(self._output_token_ids, array) self._cached_all_token_ids: List[int] = list(self._prompt_token_ids + self._output_token_ids) + @property + def cumulative_logprob(self) -> float: + return self._cumulative_logprob + @property def prompt_token_ids(self) -> Tuple[int, ...]: return self._prompt_token_ids_tuple @prompt_token_ids.setter def prompt_token_ids(self, new_prompt_token_ids) -> None: - self._prompt_token_ids = array('l', new_prompt_token_ids) - self._prompt_token_ids_tuple = tuple(new_prompt_token_ids) - self._update_cached_all_tokens() + raise NotImplementedError @property def prompt_token_ids_array(self) -> array: + """Return the prompt token ids in array type. + + Note that the array is in "I" type, and it is not compatible + with torch.long (2 bytes vs 4 bytes). So beware of the usage. + """ return self._prompt_token_ids @property @@ -166,18 +205,26 @@ def output_token_ids(self) -> Tuple[int, ...]: return tuple(self._output_token_ids) @output_token_ids.setter - def output_token_ids(self, new_output_token_ids) -> None: - self._output_token_ids = array('l', new_output_token_ids) + def output_token_ids(self, new_output_token_ids: List[int]) -> None: + self._output_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, + new_output_token_ids) self._update_cached_all_tokens() @property def output_token_ids_array(self) -> array: + """Return the prompt token ids in array type. + + Note that the array is in "I" type, and it is not compatible + with torch.long (2 bytes vs 4 bytes). So beware of the usage. 
+ """ + assert isinstance(self._output_token_ids, array) return self._output_token_ids def append_token_id(self, token_id: int, logprob: float) -> None: self._output_token_ids.append(token_id) + self._new_appended_tokens.append(token_id) self._cached_all_token_ids.append(token_id) - self.cumulative_logprob += logprob + self._cumulative_logprob += logprob def get_len(self) -> int: return len(self._output_token_ids) + len(self._prompt_token_ids) @@ -222,6 +269,7 @@ def reset_state_for_recompute(self) -> None: """ self._num_computed_tokens = 0 self._stage = SequenceStage.PREFILL + self._new_appended_tokens = [] def get_num_uncomputed_tokens(self) -> int: """Return the number of prefill tokens that are not computed.""" @@ -241,6 +289,21 @@ def get_prompt_token_ids(self) -> Tuple[int, ...]: def get_output_token_ids(self) -> Tuple[int, ...]: return self.output_token_ids + def get_delta_and_reset(self) -> SequenceDataDelta: + delta = SequenceDataDelta(self._new_appended_tokens, + self._cumulative_logprob, + self.get_num_computed_tokens(), self.stage) + # Reset delta state. + self._new_appended_tokens = [] + return delta + + def apply_delta(self, delta: SequenceDataDelta): + self._num_computed_tokens = delta.new_num_computed_tokens + self._cumulative_logprob = delta.new_cumulative_logprob + self._stage = delta.new_stage + self._output_token_ids.extend(delta.new_output_token_ids) + self._cached_all_token_ids.extend(delta.new_output_token_ids) + @property def stage(self) -> SequenceStage: return self._stage @@ -248,8 +311,9 @@ def stage(self) -> SequenceStage: def __repr__(self) -> str: return (f"SequenceData(" f"prompt_token_ids={self._prompt_token_ids}, " - f"output_token_ids={self._output_token_ids}, " - f"cumulative_logprob={self.cumulative_logprob})") + f"output_token_ids={self.output_token_ids}, " + f"cumulative_logprob={self.cumulative_logprob}, " + f"get_num_computed_tokens={self.get_num_computed_tokens()}") class Sequence: @@ -325,7 +389,8 @@ def __init__( f"invalid input {inputs}; did you forget the " "encoder input prompt fields?") - self.data = SequenceData(self.prompt_token_ids) + self.data = SequenceData( + array(VLLM_TOKEN_ID_ARRAY_TYPE, self.prompt_token_ids)) self.output_logprobs: SampleLogprobs = [] self.output_text = "" @@ -490,8 +555,8 @@ def __repr__(self) -> str: f"num_blocks={self.n_blocks}, ") -@dataclass -class SequenceGroupState: +class SequenceGroupState(msgspec.Struct, + omit_defaults=True): # type: ignore[call-arg] """Mutable state tied to a specific sequence group""" # for multi-step decoding @@ -647,14 +712,19 @@ def get_max_num_running_seqs(self) -> int: if self.sampling_params and self.sampling_params.use_beam_search: # For beam search, maximally there will always be `best_of` beam # candidates running in the future. - return self.sampling_params.best_of + best_of = self.sampling_params.best_of + assert isinstance(best_of, int) + return best_of else: - if (self.sampling_params - and self.sampling_params.best_of > self.num_seqs()): - # At prompt stage, the sequence group is not yet filled up - # and only have one sequence running. However, in the - # generation stage, we will have `best_of` sequences running. - return self.sampling_params.best_of + if self.sampling_params: + best_of = self.sampling_params.best_of + assert isinstance(best_of, int) + if best_of > self.num_seqs(): + # At prompt stage, the sequence group is not yet filled up + # and only have one sequence running. However, in the + # generation stage, we will have `best_of` sequences + # running. 
+ return best_of # At sampling stages, return the number of actual sequences # that are not finished yet. return self.num_unfinished_seqs() @@ -757,7 +827,32 @@ def __repr__(self) -> str: f"num_seqs={len(self.seqs)})") -class SequenceGroupMetadata: +class SequenceGroupMetadataDelta( + msgspec.Struct, + tag=True, # type: ignore[call-arg] + array_like=True, # type: ignore[call-arg] + omit_defaults=True): # type: ignore[call-arg] + """Delta of SequenceGroupMetadata. + + After sending the first SequenceGroupMetadata, vLLM scheduler + only sends delta to reduce the data payload size. + """ + seq_data_delta: Dict[int, SequenceDataDelta] + request_id: str + block_tables: Dict[int, List[int]] + is_prompt: bool + do_sample: bool = True + token_chunk_size: Optional[int] = None + computed_block_nums: Optional[List[int]] = None + state: Optional[SequenceGroupState] = msgspec.field( + default_factory=lambda: SequenceGroupState()) + + +class SequenceGroupMetadata( + msgspec.Struct, + tag=True, # type: ignore[call-arg] + array_like=True, # type: ignore[call-arg] + omit_defaults=True): # type: ignore[call-arg] """Metadata for a sequence group. Used to create `AttentionMetadata`. Args: @@ -789,52 +884,39 @@ class SequenceGroupMetadata: prompt_adapter_request: Prompt Adapter request. """ - def __init__( - self, - request_id: str, - is_prompt: bool, - seq_data: Dict[int, SequenceData], - sampling_params: SamplingParams, - block_tables: Dict[int, List[int]], - do_sample: bool = True, - pooling_params: Optional[PoolingParams] = None, - token_chunk_size: Optional[int] = None, - lora_request: Optional[LoRARequest] = None, - computed_block_nums: Optional[List[int]] = None, - state: Optional[SequenceGroupState] = None, - multi_modal_data: Optional["MultiModalDataDict"] = None, - encoder_seq_data: Optional[SequenceData] = None, - cross_block_table: Optional[List[int]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> None: - self.request_id = request_id - self.is_prompt = is_prompt - self.seq_data = seq_data - self.sampling_params = sampling_params - self.block_tables = block_tables - self.pooling_params = pooling_params - self.lora_request = lora_request - self.prompt_adapter_request = prompt_adapter_request - self.computed_block_nums = computed_block_nums - self.multi_modal_data = multi_modal_data - self.state = SequenceGroupState() if state is None else state - self.encoder_seq_data = encoder_seq_data - self.cross_block_table = cross_block_table - self._token_chunk_size = token_chunk_size - self.do_sample = do_sample - - # The number of speculative tokens adopted in this request. - # None means specuative decoding is not used. - # Zero means speculative decoding is disabled for some reasons. - # TODO: We should maintain this states out of the sequence group. - self.num_speculative_tokens = None - - if seq_data is not None and self._token_chunk_size is None: - if is_prompt: - self._token_chunk_size = next(iter( - seq_data.values())).get_len() + request_id: str + is_prompt: bool + seq_data: Dict[int, SequenceData] + sampling_params: SamplingParams + block_tables: Dict[int, List[int]] + do_sample: bool = True + pooling_params: Optional[PoolingParams] = None + lora_request: Optional[LoRARequest] = None + computed_block_nums: Optional[List[int]] = None + state: Optional[SequenceGroupState] = msgspec.field( + default_factory=lambda: SequenceGroupState()) + # "MultiModalDataDict" types. We have to use Any due to msgspec + # doesn't allow to have union of 2 different dicts. 
+ multi_modal_data: Optional[Any] = None + encoder_seq_data: Optional[SequenceData] = None + cross_block_table: Optional[List[int]] = None + prompt_adapter_request: Optional[PromptAdapterRequest] = None + token_chunk_size: Optional[int] = None + + ### Stateful fields that are lazily defined. ### + # The number of speculative tokens adopted in this request. + # None means specuative decoding is not used. + # Zero means speculative decoding is disabled for some reasons. + # TODO: We should maintain this states out of the sequence group. + num_speculative_tokens: Optional[int] = None + + def __post_init__(self): + if self.seq_data is not None and self.token_chunk_size is None: + if self.is_prompt: + self.token_chunk_size = next(iter( + self.seq_data.values())).get_len() else: - self._token_chunk_size = 1 + self.token_chunk_size = 1 @property def lora_int_id(self) -> int: @@ -850,18 +932,26 @@ def prompt_adapter_num_virtual_tokens(self) -> int: return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens \ if self.prompt_adapter_request else 0 - @property - def token_chunk_size(self) -> int: - """Return the number of tokens to be processed (chunk size).""" - assert self._token_chunk_size is not None - return self._token_chunk_size + def apply_delta(self, + sequence_group_metadata_delta: SequenceGroupMetadataDelta): + for id, delta in sequence_group_metadata_delta.seq_data_delta.items(): + self.seq_data[id].apply_delta(delta) + assert self.request_id == sequence_group_metadata_delta.request_id + self.block_tables = sequence_group_metadata_delta.block_tables + self.token_chunk_size = sequence_group_metadata_delta.token_chunk_size + self.do_sample = sequence_group_metadata_delta.do_sample + self.is_prompt = sequence_group_metadata_delta.is_prompt def finish_step(self) -> None: + assert self.state is not None assert self.state.current_step < self.state.num_steps self.state.current_step += 1 -class SequenceOutput: +class SequenceOutput( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + array_like=True): # type: ignore[call-arg] """The model output associated with a sequence. Args: @@ -871,16 +961,9 @@ class SequenceOutput: logprobs: The logprobs of the output token. (Token id -> logP(x_i+1 | x_0, ..., x_i)) """ - - def __init__( - self, - parent_seq_id: int, - output_token: int, - logprobs: Dict[int, Logprob], - ) -> None: - self.parent_seq_id = parent_seq_id - self.output_token = output_token - self.logprobs = logprobs + parent_seq_id: int + output_token: int + logprobs: Dict[int, Logprob] def __repr__(self) -> str: return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, " @@ -908,17 +991,15 @@ def __eq__(self, other: object) -> bool: pass -class CompletionSequenceGroupOutput(SequenceGroupOutput): +class CompletionSequenceGroupOutput( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + array_like=True): # type: ignore[call-arg] + __metaclass__ = SequenceGroupOutput """The model output associated with a completion sequence group.""" - - def __init__( - self, - samples: List[SequenceOutput], - prompt_logprobs: Optional[PromptLogprobs], - ) -> None: - self.samples = samples - # Prompt logprob for each prompt query token. - self.prompt_logprobs = prompt_logprobs + samples: List[SequenceOutput] + # Prompt logprob for each prompt query token. 
+ prompt_logprobs: Optional[PromptLogprobs] def __repr__(self) -> str: return (f"CompletionSequenceGroupOutput(samples={self.samples}, " @@ -931,14 +1012,14 @@ def __eq__(self, other: object) -> bool: and self.prompt_logprobs == other.prompt_logprobs) -class EmbeddingSequenceGroupOutput(SequenceGroupOutput): +class EmbeddingSequenceGroupOutput( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + array_like=True, # type: ignore[call-arg] +): """The model output associated with an embedding sequence group.""" - - def __init__( - self, - embeddings: List[float], - ) -> None: - self.embeddings = embeddings + __metaclass__ = SequenceGroupOutput + embeddings: List[int] def __repr__(self) -> str: return (f"EmbeddingSequenceGroupOutput(" @@ -950,8 +1031,10 @@ def __eq__(self, other: object) -> bool: return self.embeddings == other.embeddings -@dataclass -class IntermediateTensors: +class IntermediateTensors( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + array_like=True): # type: ignore[call-arg] """For all pipeline stages except the last, we need to return the hidden states and residuals to be sent to the next stage. This data structure contains the hidden states and residuals for a request. @@ -978,8 +1061,10 @@ def __repr__(self) -> str: return f"IntermediateTensors(tensors={self.tensors})" -@dataclass -class SamplerOutput: +class SamplerOutput( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + array_like=True): # type: ignore[call-arg] """For each sequence group, we generate a list of SequenceOutput object, each of which contains one possible candidate for the next token. @@ -1000,7 +1085,7 @@ class SamplerOutput: sampled_token_ids_numpy: Optional[numpy.ndarray] = None # Spec decode metrics populated by workers. - spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None + spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None # Optional last hidden states from the model. hidden_states: Optional[torch.Tensor] = None @@ -1039,12 +1124,14 @@ def __repr__(self) -> str: f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})") -@dataclass -class PoolerOutput: +class PoolerOutput( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + array_like=True): # type: ignore[call-arg] """The output from a pooling operation in the embedding model.""" outputs: List[EmbeddingSequenceGroupOutput] - spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None + spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None def __getitem__(self, idx: int): return self.outputs[idx] @@ -1083,7 +1170,8 @@ def get_all_seq_ids_and_request_ids( return seq_ids, request_id_seq_ids_mapping -class HiddenStates: +class HiddenStates(msgspec.Struct, array_like=True, + omit_defaults=True): # type: ignore[call-arg] """Hidden states corresponding to in-progress sequences. Used in speculative decoding to pass hidden states from the target model to the proposer model in the subsequent step. 
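# --------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the patch. The hunks above
# turn SequenceData into a msgspec.Struct and add SequenceDataDelta together
# with get_delta_and_reset()/apply_delta(). The snippet below shows how a
# scheduler-side SequenceData and a worker-side replica could be kept in sync
# through such a delta. Variable names are hypothetical and this is an
# untested sketch, not code from the patch.
from array import array

import msgspec

from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData,
                           SequenceDataDelta)

# Scheduler-side sequence data and the worker-side replica created from the
# first (full) SequenceGroupMetadata.
scheduler_side = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
worker_side = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))

# The scheduler appends newly sampled tokens...
scheduler_side.append_token_id(token_id=7, logprob=-0.1)
scheduler_side.append_token_id(token_id=8, logprob=-0.2)

# ...and extracts only the per-step delta to ship to workers.
delta = scheduler_side.get_delta_and_reset()

# The delta is a msgspec Struct, so it can be sent compactly over the wire
# (array_like=True encodes it as a positional msgpack array).
encoded = msgspec.msgpack.Encoder().encode(delta)
decoded = msgspec.msgpack.Decoder(SequenceDataDelta).decode(encoded)

# The worker applies the delta to its cached copy.
worker_side.apply_delta(decoded)
assert worker_side.output_token_ids == scheduler_side.output_token_ids
assert worker_side.cumulative_logprob == scheduler_side.cumulative_logprob
# --------------------------------------------------------------------------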
@@ -1091,42 +1179,53 @@ class HiddenStates: seq_ids are the sequence ids of each entry of the batch dimension of the hidden_states tensor""" - def __init__(self, seq_group_metadata_list: List[SequenceGroupMetadata], - hidden_states: torch.Tensor): - assert len(seq_group_metadata_list) == len(hidden_states) - self.seq_ids: List[int] = get_all_seq_ids(seq_group_metadata_list) - self.hidden_states: torch.Tensor = hidden_states + seq_group_metadata_list: List[SequenceGroupMetadata] + hidden_states: torch.Tensor + _seq_ids: List[int] = msgspec.field(default_factory=list) + + def __post_init__(self): + self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list) + assert len(self.seq_group_metadata_list) == len(self.hidden_states) + + @property + def seq_ids(self) -> List[int]: + return self._seq_ids def update(self, seq_group_metadata_list: List[SequenceGroupMetadata], hidden_states: torch.Tensor) -> None: """Update hidden states from target model invocation.""" assert len(seq_group_metadata_list) == len(hidden_states) - self.seq_ids.extend(get_all_seq_ids(seq_group_metadata_list)) + self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list)) self.hidden_states = torch.cat([self.hidden_states, hidden_states]) def prune(self, seq_group_metadata_list: List[SequenceGroupMetadata]) -> None: """Prune to provided list of sequence ids.""" seq_ids = get_all_seq_ids(seq_group_metadata_list) - if seq_ids != self.seq_ids: + if seq_ids != self._seq_ids: # Batch contents changed - prune removed sequences. - index = [self.seq_ids.index(seq_id) for seq_id in seq_ids] + index = [self._seq_ids.index(seq_id) for seq_id in seq_ids] self.hidden_states = self.hidden_states[index] - self.seq_ids = seq_ids + self._seq_ids = seq_ids -@dataclass -class ExecuteModelRequest: +class ExecuteModelRequest( + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True): # type: ignore[call-arg] """The model execution request, containing CPU metadata only. The LLM engine should create an instance of this class for each request batch.""" # The sequence group metadata list. - seq_group_metadata_list: List[SequenceGroupMetadata] + seq_group_metadata_list: List[Union[SequenceGroupMetadata, + SequenceGroupMetadataDelta]] # Blocks to swap in. List of CPU -> GPU block number. - blocks_to_swap_in: List[Tuple[int, int]] = field(default_factory=list) + blocks_to_swap_in: List[Tuple[int, + int]] = msgspec.field(default_factory=list) # Blocks to swap out. List of GPU -> CPU block number. - blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list) + blocks_to_swap_out: List[Tuple[int, + int]] = msgspec.field(default_factory=list) # Blocks to copy. Source to dest block. - blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list) + blocks_to_copy: List[Tuple[int, int]] = msgspec.field(default_factory=list) # Virtual engine ID for pipeline parallel. virtual_engine: int = 0 # The number of slots for lookahead decoding. @@ -1138,7 +1237,7 @@ class ExecuteModelRequest: # The number of forward steps to run. num_steps: int = 1 # Finished request ids since last step. - finished_requests_ids: List[str] = field(default_factory=list) + finished_requests_ids: List[str] = msgspec.field(default_factory=list) # The last sampled token ids for multi step decoding. 
last_sampled_token_ids: Optional[torch.Tensor] = None @@ -1148,6 +1247,7 @@ def is_first_multi_step(self) -> bool: # steps assert len(self.seq_group_metadata_list) > 0 first_seq_group = self.seq_group_metadata_list[0] + assert first_seq_group.state is not None return first_seq_group.state.current_step == 0 @property @@ -1156,6 +1256,7 @@ def is_last_step(self) -> bool: # steps assert len(self.seq_group_metadata_list) > 0 first_seq_group = self.seq_group_metadata_list[0] + assert first_seq_group.state is not None num_steps = first_seq_group.state.num_steps current_step = first_seq_group.state.current_step return num_steps - current_step == 1 @@ -1165,10 +1266,13 @@ def current_step(self) -> int: # TODO(will) make this be able to handle batches with variable number of # steps assert len(self.seq_group_metadata_list) > 0 - return self.seq_group_metadata_list[0].state.current_step + state = self.seq_group_metadata_list[0].state + assert state is not None + return state.current_step def clone( - self, seq_group_metadata_list: List[SequenceGroupMetadata] + self, seq_group_metadata_list: List[Union[SequenceGroupMetadata, + SequenceGroupMetadataDelta]] ) -> "ExecuteModelRequest": """Clone the request with a new sequence group metadata list.""" return ExecuteModelRequest( diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index 45eaeb51c5c0f..aec4847b96c35 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -1,11 +1,13 @@ +from array import array from itertools import chain, count from typing import Iterator, List, Tuple import torch from vllm import SamplingParams -from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData, - SequenceGroupMetadata, get_all_seq_ids) +from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, ExecuteModelRequest, + SamplerOutput, SequenceData, SequenceGroupMetadata, + get_all_seq_ids) from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) from vllm.spec_decode.util import (nvtx_range, sampler_output_to_torch, @@ -293,14 +295,15 @@ def _create_single_target_seq_group_metadata( input sequence. """ seq_data = seq_group_metadata.seq_data[seq_id] - prompt_token_ids = seq_data.get_prompt_token_ids() + prompt_token_ids = seq_data.prompt_token_ids_array new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids] new_seq_data_dict = { target_seq_id: SequenceData( - prompt_token_ids=prompt_token_ids, - output_token_ids=new_output_token_ids, + prompt_token_ids, + _output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE, + new_output_token_ids), ), } # This is a hack. Technically, spec decoding should compute diff --git a/vllm/spec_decode/metrics.py b/vllm/spec_decode/metrics.py index 9036d117041f0..ad4e2dc879d7b 100644 --- a/vllm/spec_decode/metrics.py +++ b/vllm/spec_decode/metrics.py @@ -1,7 +1,7 @@ import time -from dataclasses import dataclass from typing import Callable, Optional +import msgspec import torch from vllm.model_executor.layers.spec_decode_base_sampler import ( @@ -9,8 +9,10 @@ from vllm.utils import is_pin_memory_available -@dataclass -class SpecDecodeWorkerMetrics: +class SpecDecodeWorkerMetrics( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + array_like=True): # type: ignore[call-arg] """Dataclass holding metrics emitted from the spec decode worker. 
""" diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 8f4372e20d2e7..4e843a5d94372 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -1,7 +1,7 @@ """A GPU worker class.""" import gc import os -from typing import List, Optional, Set, Tuple, Type +from typing import Dict, List, Optional, Set, Tuple, Type, Union import torch import torch.distributed @@ -18,7 +18,9 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.platforms import current_platform from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sequence import ExecuteModelRequest +from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, + SamplerOutput, SequenceGroupMetadata, + SequenceGroupMetadataDelta) from vllm.worker.cache_engine import CacheEngine from vllm.worker.embedding_model_runner import EmbeddingModelRunner from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner @@ -112,6 +114,7 @@ def __init__( self.cache_engine: List[CacheEngine] # Initialize gpu_cache as embedding models don't initialize kv_caches self.gpu_cache: Optional[List[List[torch.Tensor]]] = None + self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {} def _is_encoder_decoder_model(self): return self.model_config.is_encoder_decoder_model @@ -306,6 +309,63 @@ def execute_worker(self, worker_input: WorkerInput) -> None: and worker_input.blocks_to_copy.numel() > 0): self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy) + def _get_cached_seq_group_metadata( + self, + seq_group_metadata_list: List[Union[SequenceGroupMetadata, + SequenceGroupMetadataDelta]], + finished_request_ids: List[str]) -> List[SequenceGroupMetadata]: + """Return a list of cached Sequence Group Metadata after updating its + state. + + It is used because scheduler only sends delta to workers to reduce + the data payload size. The function also cleans up cache based on + a given `finished_request_ids`. + """ + new_seq_group_metadata_list = [] + for metadata_or_delta in seq_group_metadata_list: + request_id = metadata_or_delta.request_id + if request_id not in self._seq_group_metadata_cache: + # The first prefill. + assert isinstance(metadata_or_delta, SequenceGroupMetadata) + self._seq_group_metadata_cache[request_id] = metadata_or_delta + else: + # The first prefill is already cached. + if isinstance(metadata_or_delta, SequenceGroupMetadataDelta): + self._seq_group_metadata_cache[request_id].apply_delta( + metadata_or_delta) + else: + # If metadata snapshot is sent again, it is + # preempted. Reset the cache because we need to start + # from scratch. 
+ assert isinstance(metadata_or_delta, SequenceGroupMetadata) + self._seq_group_metadata_cache[ + request_id] = metadata_or_delta + + new_seq_group_metadata_list.append( + self._seq_group_metadata_cache[request_id]) + + # Clean up finished ids + for finished_id in finished_request_ids: + del self._seq_group_metadata_cache[finished_id] + + return new_seq_group_metadata_list + + def _execute_model_spmd( + self, + execute_model_req: ExecuteModelRequest, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Optional[List[SamplerOutput]]: + if execute_model_req is not None: + new_seq_group_metadata_list = self._get_cached_seq_group_metadata( + execute_model_req.seq_group_metadata_list, + execute_model_req.finished_requests_ids) + + execute_model_req.seq_group_metadata_list = ( + new_seq_group_metadata_list) + output = super()._execute_model_spmd(execute_model_req, + intermediate_tensors) + return output + def add_lora(self, lora_request: LoRARequest) -> bool: return self.model_runner.add_lora(lora_request)
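# --------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the patch. The worker code
# above caches the full SequenceGroupMetadata the first time a request id is
# seen, applies SequenceGroupMetadataDelta objects on later steps, treats a
# repeated full snapshot as a preemption (the cache entry is overwritten),
# and evicts entries for finished request ids. The snippet below walks
# through the full-then-delta handshake with the classes added in this
# patch; variable names are hypothetical and the snippet is untested.
from array import array

from vllm.sampling_params import SamplingParams
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData,
                           SequenceGroupMetadata, SequenceGroupMetadataDelta)

# Scheduler-side state for request "req-0".
sched_seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [11, 12, 13]))

# Worker side: the first (full) SequenceGroupMetadata received for a request
# is cached, keyed by request id. The worker normally gets its own
# deserialized copy, which we imitate here with a second SequenceData.
cached = SequenceGroupMetadata(
    request_id="req-0",
    is_prompt=True,
    seq_data={0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [11, 12, 13]))},
    sampling_params=SamplingParams(),
    block_tables={0: [0, 1]},
)
cache = {cached.request_id: cached}  # stands in for _seq_group_metadata_cache

# On later steps the scheduler appends tokens and ships only a delta.
sched_seq_data.append_token_id(token_id=42, logprob=-0.5)
delta = SequenceGroupMetadataDelta(
    seq_data_delta={0: sched_seq_data.get_delta_and_reset()},
    request_id="req-0",
    block_tables={0: [0, 1, 2]},
    is_prompt=False,
    do_sample=True,
    token_chunk_size=1,
)

# The worker applies the delta to its cached full metadata, mirroring
# Worker._get_cached_seq_group_metadata above.
cache[delta.request_id].apply_delta(delta)
assert cache["req-0"].seq_data[0].get_output_token_ids() == (42,)
assert cache["req-0"].is_prompt is False
# --------------------------------------------------------------------------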