-
-
Notifications
You must be signed in to change notification settings - Fork 4.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Core] Optimize SPMD architecture with delta + serialization optimization #7109
Changes from 32 commits
d41f4c5
5741a83
d31d73f
36e786d
71e40c1
7e69242
0de9f23
64faf75
de4e43e
dc7c445
700e4a3
a906a9d
4af6699
1e6196b
0ea6e41
35e9637
912b88b
5bab192
eb2cb14
007fe86
ce64b8d
e8e29e1
751bdb1
d91aa78
06774d1
1af8dc2
6e6ac92
fa0d077
b5a88ec
d2e14ca
8be3c8e
c42c6c5
c55c8f6
2ba99e2
e2c850b
41ec6d1
925c928
9d3dee5
d041e9c
c4b3682
f938e00
32cb984
5a4f27e
c921877
ae1fb21
c3abcc5
3e1325e
652c258
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import msgspec | ||
|
||
from vllm.executor.msgspec_utils import decode_hook, encode_hook | ||
from vllm.sequence import ExecuteModelRequest | ||
|
||
from ..spec_decode.utils import create_batch | ||
|
||
|
||
def test_msgspec_serialization(): | ||
num_lookahead_slots = 4 | ||
seq_group_metadata_list, _, _ = create_batch(16, num_lookahead_slots) | ||
execute_model_req = ExecuteModelRequest( | ||
seq_group_metadata_list=seq_group_metadata_list, | ||
num_lookahead_slots=num_lookahead_slots, | ||
running_queue_size=4) | ||
|
||
encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook) | ||
decoder = msgspec.msgpack.Decoder(ExecuteModelRequest, | ||
dec_hook=decode_hook) | ||
req = decoder.decode(encoder.encode(execute_model_req)) | ||
expected = execute_model_req.seq_group_metadata_list | ||
actual = req.seq_group_metadata_list | ||
assert (len(expected) == len(actual)) | ||
expected = expected[0] | ||
actual = actual[0] | ||
|
||
assert expected.block_tables == actual.block_tables | ||
assert expected.is_prompt == actual.is_prompt | ||
assert expected.request_id == actual.request_id | ||
assert (expected.seq_data[0].prompt_token_ids == | ||
actual.seq_data[0].prompt_token_ids) | ||
assert (expected.seq_data[0].output_token_ids == | ||
actual.seq_data[0].output_token_ids) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
import itertools | ||
import random | ||
from array import array | ||
from typing import Dict, List, Optional, Tuple | ||
from unittest.mock import Mock, patch | ||
|
||
|
@@ -56,7 +57,7 @@ def _do_sample( | |
SequenceGroupMetadata( | ||
request_id=f"test_{i}", | ||
is_prompt=True, | ||
seq_data={0: SequenceData([1, 2, 3])}, | ||
seq_data={0: SequenceData(array("I", [1, 2, 3]))}, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we can also accept list input, and do the convert inside I'd like to reduce unnecessary files touched by this PR. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The issue is I found msgspec doesn't allow me to have input of Union[List, array] (it cannot serialize union with the same type. Other example is sth like Union[OrderedDict, dict]). Technically we can just implicitly convert all list though (I think the type there is just a hint) if you prefer that way |
||
sampling_params=sampling_params, | ||
block_tables={0: [1]}, | ||
)) | ||
|
@@ -201,7 +202,7 @@ def create_sampling_params(min_tokens, | |
|
||
def create_sequence_data(num_input=3, num_generated=0): | ||
seq_data = SequenceData( | ||
random.choices(range(0, VOCAB_SIZE), k=num_input)) | ||
array("I", random.choices(range(0, VOCAB_SIZE), k=num_input))) | ||
if num_generated > 0: | ||
seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE), | ||
k=num_generated) | ||
|
@@ -504,7 +505,7 @@ def test_sampler_mixed(seed: int, device: str): | |
SequenceGroupMetadata( | ||
request_id=f"test_{i}", | ||
is_prompt=True, | ||
seq_data={0: SequenceData([1, 2, 3])}, | ||
seq_data={0: SequenceData(array("I", [1, 2, 3]))}, | ||
sampling_params=sampling_params, | ||
block_tables={0: [1]}, | ||
)) | ||
|
@@ -600,7 +601,7 @@ def test_sampler_top_k_top_p(seed: int, device: str): | |
SequenceGroupMetadata( | ||
request_id=f"test_{i}", | ||
is_prompt=True, | ||
seq_data={0: SequenceData([1, 2, 3])}, | ||
seq_data={0: SequenceData(array("I", [1, 2, 3]))}, | ||
sampling_params=SamplingParams( | ||
temperature=1, | ||
top_k=top_k, | ||
|
@@ -650,7 +651,7 @@ def test_sampling_params(sampling_params: List[SamplingParams]): | |
SequenceGroupMetadata( | ||
request_id=f"test_{i}", | ||
is_prompt=True, | ||
seq_data={0: SequenceData([1, 2, 3])}, | ||
seq_data={0: SequenceData(array("I", [1, 2, 3]))}, | ||
sampling_params=sampling_params[i], | ||
block_tables={0: [1]}, | ||
)) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I remember I once tried this library, but its serialization scopt is quite limited. Some classes cannot be serialized via this library. Do you have this experience when use it in ray?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we used this library for our internal fork before. For this PR, I have to implement custom reduce for
array
type. And Union of the same 2 types are not supported (e.g., Union[OrderedDict, Dict] kind of thing). But I think we can get around this much