[Core] Optimize SPMD architecture with delta + serialization optimization #7109

Merged Aug 19, 2024

48 commits (changes shown from 46 commits)

Commits
All 48 commits are by rkooo567.

d41f4c5  wip (Jul 25, 2024)
5741a83  fix original arch issue (Jul 25, 2024)
d31d73f  should work now. (Jul 25, 2024)
36e786d  working (Jul 25, 2024)
71e40c1  . (Jul 25, 2024)
7e69242  pickle (Jul 25, 2024)
0de9f23  msgpack optimization (Jul 27, 2024)
64faf75  Merge branch 'main' into serialization-opt (Jul 29, 2024)
de4e43e  ip (Jul 30, 2024)
dc7c445  . (Jul 30, 2024)
700e4a3  Merge branch 'main' into serialization-opt (Jul 30, 2024)
a906a9d  msgspec migration done (Jul 31, 2024)
4af6699  ip. preemption and chunked prefill not working yet. (Aug 1, 2024)
1e6196b  working e2e (Aug 3, 2024)
0ea6e41  Merge branch 'main-before-server' into spmd-and-pp (Aug 3, 2024)
35e9637  working finally (Aug 3, 2024)
912b88b  . (Aug 5, 2024)
5bab192  working (Aug 5, 2024)
eb2cb14  working (Aug 5, 2024)
007fe86  fix a test failure. (Aug 5, 2024)
ce64b8d  . (Aug 7, 2024)
e8e29e1  fixed (Aug 10, 2024)
751bdb1  addressed code review. (Aug 12, 2024)
d91aa78  lint (Aug 12, 2024)
06774d1  Merge branch 'main' into spmd-and-pp (Aug 12, 2024)
1af8dc2  ip (Aug 12, 2024)
6e6ac92  all working (Aug 12, 2024)
fa0d077  lint (Aug 12, 2024)
b5a88ec  done (Aug 12, 2024)
d2e14ca  code review. (Aug 12, 2024)
8be3c8e  addressed code review. (Aug 13, 2024)
c42c6c5  Merge branch 'main' into spmd-and-pp (Aug 13, 2024)
c55c8f6  Merge branch 'main' into spmd-and-pp (Aug 13, 2024)
2ba99e2  lint fix (Aug 13, 2024)
e2c850b  Merge branch 'main' into spmd-and-pp (Aug 14, 2024)
41ec6d1  Merge branch 'main' into spmd-and-pp (Aug 14, 2024)
925c928  fix lint (Aug 14, 2024)
9d3dee5  Merge branch 'main' into spmd-and-pp (Aug 15, 2024)
d041e9c  Addressed code review. (Aug 15, 2024)
c4b3682  Merge branch 'main' into spmd-and-pp (Aug 16, 2024)
f938e00  fix pydantic not compatible to msggspec.Struct. (Aug 17, 2024)
32cb984  addressed (Aug 17, 2024)
5a4f27e  Merge branch 'main' into spmd-and-pp (Aug 17, 2024)
c921877  fixed (Aug 17, 2024)
ae1fb21  temporarily use dataclass (Aug 17, 2024)
c3abcc5  Merge branch 'main' into spmd-and-pp (Aug 17, 2024)
3e1325e  Addressed code review. (Aug 18, 2024)
652c258  lint (Aug 18, 2024)
1 change: 1 addition & 0 deletions requirements-common.txt
@@ -21,6 +21,7 @@ outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0
typing_extensions >= 4.10
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
pyzmq
msgspec
Member: I remember trying this library once, but its serialization scope is quite limited; some classes cannot be serialized with it. Did you run into that when using it in Ray?

Collaborator Author (rkooo567): We used this library in our internal fork before. For this PR I had to implement a custom reduce for the array type, and a Union of two overlapping types is not supported (e.g., Union[OrderedDict, Dict]), but I think we can work around these limitations.
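
For context, here is a minimal sketch of what such custom hooks for array could look like. This is an illustration only, not the exact code in vllm/executor/msgspec_utils.py, and the "l" typecode is an assumption:

```python
# Illustrative sketch of msgspec enc_hook/dec_hook for array.array,
# which msgspec cannot serialize natively. Not the PR's exact implementation.
from array import array
from typing import Any, Type

import msgspec

TOKEN_ID_ARRAY_TYPE = "l"  # assumption: token ids stored as signed longs


def encode_hook(obj: Any) -> Any:
    # Represent the array as its raw bytes for msgpack.
    if isinstance(obj, array):
        return obj.tobytes()
    raise NotImplementedError(f"Objects of type {type(obj)} are not supported")


def decode_hook(target_type: Type, obj: Any) -> Any:
    # Rebuild the array from the raw-bytes payload.
    if target_type is array:
        arr = array(TOKEN_ID_ARRAY_TYPE)
        arr.frombytes(obj)
        return arr
    return obj


encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
decoder = msgspec.msgpack.Decoder(array, dec_hook=decode_hook)
assert decoder.decode(encoder.encode(array("l", [1, 2, 3]))) == array("l", [1, 2, 3])
```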

librosa # Required for audio processing
soundfile # Required for audio processing
gguf == 0.9.1
18 changes: 18 additions & 0 deletions tests/basic_correctness/test_preemption.py
@@ -8,6 +8,7 @@
import pytest
from prometheus_client import REGISTRY

import vllm.envs as envs
from vllm import SamplingParams
from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT,
ENABLE_ARTIFICIAL_PREEMPT)
@@ -24,6 +25,13 @@
"tests/basic_correctness/test_preemption.py`")


@pytest.fixture
def worker_use_ray() -> bool:
# When the SPMD worker is used, set worker_use_ray=True
# to check that the delta input optimization works with preemption.
return envs.VLLM_USE_RAY_SPMD_WORKER


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [96])
@@ -36,6 +44,7 @@ def test_chunked_prefill_recompute(
dtype: str,
max_tokens: int,
chunked_prefill_token_size: int,
worker_use_ray: bool,
) -> None:
"""Ensure that chunked prefill works with preemption."""
max_num_seqs = min(chunked_prefill_token_size, 256)
@@ -54,6 +63,7 @@ def test_chunked_prefill_recompute(
max_num_batched_tokens=max_num_batched_tokens,
enable_chunked_prefill=enable_chunked_prefill,
max_num_seqs=max_num_seqs,
worker_use_ray=worker_use_ray,
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
@@ -79,6 +89,7 @@ def test_preemption(
model: str,
dtype: str,
max_tokens: int,
worker_use_ray: bool,
) -> None:
"""By default, recompute preemption is enabled"""

@@ -89,6 +100,7 @@ def test_preemption(
model,
dtype=dtype,
disable_log_stats=False,
worker_use_ray=worker_use_ray,
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
@@ -132,6 +144,7 @@ def test_swap(
dtype: str,
max_tokens: int,
beam_width: int,
worker_use_ray: bool,
) -> None:
"""Use beam search enables swapping."""
example_prompts = example_prompts[:1]
@@ -144,6 +157,7 @@ def test_swap(
dtype=dtype,
swap_space=10,
disable_log_stats=False,
worker_use_ray=worker_use_ray,
) as vllm_model:
vllm_outputs = vllm_model.generate_beam_search(example_prompts,
beam_width, max_tokens)
@@ -188,6 +202,7 @@ def test_swap_infeasible(
dtype: str,
max_tokens: int,
beam_width: int,
worker_use_ray: bool,
) -> None:
"""Verify infeasible swap request will be ignored."""
BLOCK_SIZE = 16
@@ -204,6 +219,7 @@ def test_swap_infeasible(
# decode blocks are not enough to finish.
num_gpu_blocks_override=prefill_blocks + decode_blocks,
max_model_len=(prefill_blocks + decode_blocks) * BLOCK_SIZE,
worker_use_ray=worker_use_ray,
) as vllm_model:
sampling_params = SamplingParams(n=beam_width,
use_beam_search=True,
@@ -230,6 +246,7 @@ def test_preemption_infeasible(
model: str,
dtype: str,
max_tokens: int,
worker_use_ray: bool,
) -> None:
"""Verify infeasible preemption request will be ignored."""
BLOCK_SIZE = 16
@@ -244,6 +261,7 @@ def test_preemption_infeasible(
# ignored instead of hanging forever.
num_gpu_blocks_override=prefill_blocks + decode_blocks // 2,
max_model_len=((prefill_blocks + decode_blocks // 2) * BLOCK_SIZE),
worker_use_ray=worker_use_ray,
) as vllm_model:
sampling_params = SamplingParams(max_tokens=max_tokens,
ignore_eos=True)
33 changes: 33 additions & 0 deletions tests/core/test_serialization.py
@@ -0,0 +1,33 @@
import msgspec

from vllm.executor.msgspec_utils import decode_hook, encode_hook
from vllm.sequence import ExecuteModelRequest

from ..spec_decode.utils import create_batch


def test_msgspec_serialization():
num_lookahead_slots = 4
seq_group_metadata_list, _, _ = create_batch(16, num_lookahead_slots)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=num_lookahead_slots,
running_queue_size=4)

encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
decoder = msgspec.msgpack.Decoder(ExecuteModelRequest,
dec_hook=decode_hook)
req = decoder.decode(encoder.encode(execute_model_req))
expected = execute_model_req.seq_group_metadata_list
actual = req.seq_group_metadata_list
assert (len(expected) == len(actual))
expected = expected[0]
actual = actual[0]

assert expected.block_tables == actual.block_tables
assert expected.is_prompt == actual.is_prompt
assert expected.request_id == actual.request_id
assert (expected.seq_data[0].prompt_token_ids ==
actual.seq_data[0].prompt_token_ids)
assert (expected.seq_data[0].output_token_ids ==
actual.seq_data[0].output_token_ids)
4 changes: 3 additions & 1 deletion tests/distributed/test_basic_distributed_correctness.py
@@ -22,7 +22,9 @@
@pytest.mark.skipif(cuda_device_count_stateless() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize(
"model, distributed_executor_backend, attention_backend, test_suite", [
"model, distributed_executor_backend, attention_backend, "
"test_suite", [
("facebook/opt-125m", "ray", "", "L4"),
("facebook/opt-125m", "ray", "", "L4"),
("facebook/opt-125m", "mp", "", "L4"),
("meta-llama/Llama-2-7b-hf", "ray", "", "L4"),
8 changes: 8 additions & 0 deletions tests/distributed/test_chunked_prefill_distributed.py
@@ -6,6 +6,8 @@
```
"""

import os

import pytest

from vllm.utils import cuda_device_count_stateless
@@ -17,6 +19,7 @@
@pytest.mark.skipif(cuda_device_count_stateless() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("model, distributed_executor_backend", [
("facebook/opt-125m", "ray"),
("facebook/opt-125m", "ray"),
("meta-llama/Llama-2-7b-hf", "ray"),
("facebook/opt-125m", "mp"),
@@ -30,6 +33,11 @@ def test_models(
model: str,
distributed_executor_backend: str,
) -> None:
if model == "meta-llama/Llama-2-7b-hf" and distributed_executor_backend == "ray": # noqa
assert distributed_executor_backend == "ray"
# test ray adag
os.environ['VLLM_USE_RAY_SPMD_WORKER'] = "1"
os.environ['VLLM_USE_RAY_COMPILED_DAG'] = "1"

dtype = "half"
max_tokens = 5
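As a usage note, the Ray SPMD/aDAG path exercised by the branch above appears to be opted into via environment variables alone. A minimal sketch mirroring the assignments in this test follows; treating these two variables as the complete opt-in is an assumption:

```python
import os

# Mirror the test above: enable the SPMD worker and Ray's compiled DAG
# before the vLLM engine is constructed, so vllm.envs picks them up.
os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"
```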
25 changes: 19 additions & 6 deletions tests/samplers/test_sampler.py
@@ -1,5 +1,6 @@
import itertools
import random
from array import array
from typing import Dict, List, Optional, Tuple
from unittest.mock import Mock, patch

@@ -10,7 +11,8 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
SequenceData, SequenceGroupMetadata)
from vllm.utils import Counter, is_pin_memory_available


@@ -56,7 +58,9 @@ def _do_sample(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
seq_data={
0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
},
sampling_params=sampling_params,
block_tables={0: [1]},
))
@@ -201,7 +205,8 @@ def create_sampling_params(min_tokens,

def create_sequence_data(num_input=3, num_generated=0):
seq_data = SequenceData(
random.choices(range(0, VOCAB_SIZE), k=num_input))
array(VLLM_TOKEN_ID_ARRAY_TYPE,
random.choices(range(0, VOCAB_SIZE), k=num_input)))
if num_generated > 0:
seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE),
k=num_generated)
@@ -504,7 +509,9 @@ def test_sampler_mixed(seed: int, device: str):
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
seq_data={
0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
},
sampling_params=sampling_params,
block_tables={0: [1]},
))
@@ -600,7 +607,9 @@ def test_sampler_top_k_top_p(seed: int, device: str):
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
seq_data={
0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
},
sampling_params=SamplingParams(
temperature=1,
top_k=top_k,
@@ -650,7 +659,11 @@ def test_sampling_params(sampling_params: List[SamplingParams]):
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
seq_data={
0:
SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
[1, 2, 3]))
},
sampling_params=sampling_params[i],
block_tables={0: [1]},
))
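The recurring change in these test diffs is that SequenceData is now constructed from a typed array of token ids rather than a plain Python list. A minimal before/after sketch, assuming a vLLM build with this PR is importable (the final assert is illustrative):

```python
from array import array

from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData

# Before this PR the tests passed a plain list:
#     SequenceData([1, 2, 3])
# After this PR the token ids are stored in a typed array, which the
# msgspec hooks can serialize as a compact raw-bytes payload.
seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
assert list(seq_data.prompt_token_ids) == [1, 2, 3]
```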
9 changes: 6 additions & 3 deletions tests/spec_decode/utils.py
@@ -1,3 +1,4 @@
from array import array
from itertools import count
from typing import Callable, Dict, List, Optional
from typing import Sequence as GenericSequence
@@ -9,7 +10,8 @@
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.utils import set_random_seed
from vllm.sampling_params import SamplingParams
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE,
CompletionSequenceGroupOutput, Logprob,
SamplerOutput, SequenceData, SequenceGroupMetadata,
SequenceOutput)
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
@@ -138,8 +140,9 @@ def create_seq_group_metadata_from_prompts(
seq_data={
i:
SequenceData(
prompt_token_ids=prompt_token_ids[:],
output_token_ids=cont_token_ids[:],
array(VLLM_TOKEN_ID_ARRAY_TYPE, prompt_token_ids[:]),
_output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE,
cont_token_ids[:]),
),
},
sampling_params=SamplingParams(temperature=0.0, ),
8 changes: 6 additions & 2 deletions tests/test_logits_processor.py
@@ -1,4 +1,5 @@
import random
from array import array
from typing import Tuple
from unittest.mock import patch

@@ -8,7 +9,8 @@
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
SequenceData, SequenceGroupMetadata)
from vllm.utils import is_pin_memory_available


@@ -69,7 +71,9 @@ def pick_ith(token_ids, logits):
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
seq_data={
0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
},
sampling_params=SamplingParams(temperature=0,
logits_processors=[pick_ith]),
block_tables={0: [1]},
7 changes: 5 additions & 2 deletions tests/test_sequence.py
@@ -1,6 +1,9 @@
from array import array

import pytest

from vllm.sequence import (CompletionSequenceGroupOutput, SamplerOutput,
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE,
CompletionSequenceGroupOutput, SamplerOutput,
SequenceData, SequenceOutput)

from .core.utils import create_dummy_prompt
@@ -54,7 +57,7 @@ def test_sampler_output_eq(sample_outputs):


def test_sequence_data_prefill():
seq_data = SequenceData(prompt_token_ids=[1, 2, 3, 4])
seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3, 4]))
assert seq_data.get_num_uncomputed_tokens() == 4
assert seq_data.get_num_computed_tokens() == 0
# advance by 2
16 changes: 11 additions & 5 deletions tests/worker/test_encoder_decoder_model_runner.py
@@ -1,10 +1,12 @@
from array import array
from typing import List

import pytest
import torch

from vllm.engine.arg_utils import EngineArgs
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
SequenceData, SequenceGroupMetadata)
from vllm.utils import is_cpu
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner

@@ -125,10 +127,12 @@ def test_prepare_prompt(
# make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len)
seq_data = SequenceData(list(range(seq_len)))
seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
range(seq_len)))
encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
encoder_seq_lens.append(encoder_seq_len)
encoder_seq_data = SequenceData(list(range(encoder_seq_len)))
encoder_seq_data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, range(encoder_seq_len)))
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
@@ -319,10 +323,12 @@ def test_prepare_decode(
# make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len)
seq_data = SequenceData(list(range(seq_len)))
seq_data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len))))
encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
encoder_seq_lens.append(encoder_seq_len)
encoder_seq_data = SequenceData(list(range(encoder_seq_len)))
encoder_seq_data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len))))
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=False,