[Core] Optimize SPMD architecture with delta + serialization optimization #7109

Merged
48 commits merged on Aug 19, 2024

Changes from 32 commits

Commits (48)
d41f4c5
wip
rkooo567 Jul 25, 2024
5741a83
fix original arch issue
rkooo567 Jul 25, 2024
d31d73f
should work now.
rkooo567 Jul 25, 2024
36e786d
working
rkooo567 Jul 25, 2024
71e40c1
.
rkooo567 Jul 25, 2024
7e69242
pickle
rkooo567 Jul 25, 2024
0de9f23
msgpack optimization
rkooo567 Jul 27, 2024
64faf75
Merge branch 'main' into serialization-opt
rkooo567 Jul 29, 2024
de4e43e
ip
rkooo567 Jul 30, 2024
dc7c445
.
rkooo567 Jul 30, 2024
700e4a3
Merge branch 'main' into serialization-opt
rkooo567 Jul 30, 2024
a906a9d
msgspec migration done
rkooo567 Jul 31, 2024
4af6699
ip. preemption and chunked prefill not working yet.
rkooo567 Aug 1, 2024
1e6196b
working e2e
rkooo567 Aug 3, 2024
0ea6e41
Merge branch 'main-before-server' into spmd-and-pp
rkooo567 Aug 3, 2024
35e9637
working finally
rkooo567 Aug 3, 2024
912b88b
.
rkooo567 Aug 5, 2024
5bab192
working
rkooo567 Aug 5, 2024
eb2cb14
working
rkooo567 Aug 5, 2024
007fe86
fix a test failure.
rkooo567 Aug 5, 2024
ce64b8d
.
rkooo567 Aug 7, 2024
e8e29e1
fixed
rkooo567 Aug 10, 2024
751bdb1
addressed code review.
rkooo567 Aug 12, 2024
d91aa78
lint
rkooo567 Aug 12, 2024
06774d1
Merge branch 'main' into spmd-and-pp
rkooo567 Aug 12, 2024
1af8dc2
ip
rkooo567 Aug 12, 2024
6e6ac92
all working
rkooo567 Aug 12, 2024
fa0d077
lint
rkooo567 Aug 12, 2024
b5a88ec
done
rkooo567 Aug 12, 2024
d2e14ca
code review.
rkooo567 Aug 12, 2024
8be3c8e
addressed code review.
rkooo567 Aug 13, 2024
c42c6c5
Merge branch 'main' into spmd-and-pp
rkooo567 Aug 13, 2024
c55c8f6
Merge branch 'main' into spmd-and-pp
rkooo567 Aug 13, 2024
2ba99e2
lint fix
rkooo567 Aug 13, 2024
e2c850b
Merge branch 'main' into spmd-and-pp
rkooo567 Aug 14, 2024
41ec6d1
Merge branch 'main' into spmd-and-pp
rkooo567 Aug 14, 2024
925c928
fix lint
rkooo567 Aug 14, 2024
9d3dee5
Merge branch 'main' into spmd-and-pp
rkooo567 Aug 15, 2024
d041e9c
Addressed code review.
rkooo567 Aug 15, 2024
c4b3682
Merge branch 'main' into spmd-and-pp
rkooo567 Aug 16, 2024
f938e00
fix pydantic not compatible to msggspec.Struct.
rkooo567 Aug 17, 2024
32cb984
addressed
rkooo567 Aug 17, 2024
5a4f27e
Merge branch 'main' into spmd-and-pp
rkooo567 Aug 17, 2024
c921877
fixed
rkooo567 Aug 17, 2024
ae1fb21
temporarily use dataclass
rkooo567 Aug 17, 2024
c3abcc5
Merge branch 'main' into spmd-and-pp
rkooo567 Aug 17, 2024
3e1325e
Addressed code review.
rkooo567 Aug 18, 2024
652c258
lint
rkooo567 Aug 18, 2024
1 change: 1 addition & 0 deletions requirements-common.txt
@@ -20,4 +20,5 @@ outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0
typing_extensions >= 4.10
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
pyzmq
msgspec
Member:
I remember I once tried this library, but its serialization scope is quite limited. Some classes cannot be serialized via this library. Did you run into this when you used it in Ray?

Collaborator Author (rkooo567):
We used this library in our internal fork before. For this PR, I had to implement a custom reduce for the array type. Also, a Union of two types of the same kind is not supported (e.g., something like Union[OrderedDict, Dict]). But I think we can mostly work around this.

gguf == 0.9.1
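As context for the discussion above, here is a minimal sketch of how msgspec can be taught to handle `array.array` through encoder/decoder hooks. The typecode constant, hook bodies, and the `TokenBatch` container are illustrative assumptions, not necessarily what `vllm/executor/msgspec_utils.py` implements; the test added in `tests/core/test_serialization.py` below exercises the real hooks over a full `ExecuteModelRequest`.

```python
from array import array
from typing import Any, Type

import msgspec

# Assumed typecode for token-id arrays; the tests in this PR use array("I", ...).
TOKEN_ARRAY_TYPE = "I"


def encode_hook(obj: Any) -> Any:
    """Tell msgspec how to encode types it does not support natively."""
    if isinstance(obj, array):
        # Ship the raw buffer; the typecode is fixed by convention.
        return obj.tobytes()
    raise NotImplementedError(f"Cannot encode objects of type {type(obj)}")


def decode_hook(type: Type, obj: Any) -> Any:
    """Rebuild unsupported types from their wire representation."""
    if type is array:
        arr = array(TOKEN_ARRAY_TYPE)
        arr.frombytes(obj)
        return arr
    raise NotImplementedError(f"Cannot decode objects of type {type}")


class TokenBatch(msgspec.Struct):
    # Hypothetical container standing in for fields such as
    # SequenceData.prompt_token_ids in the real request objects.
    token_ids: array


encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
decoder = msgspec.msgpack.Decoder(TokenBatch, dec_hook=decode_hook)

batch = TokenBatch(token_ids=array(TOKEN_ARRAY_TYPE, [1, 2, 3]))
assert decoder.decode(encoder.encode(batch)).token_ids == batch.token_ids
```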
18 changes: 18 additions & 0 deletions tests/basic_correctness/test_preemption.py
@@ -8,6 +8,7 @@
import pytest
from prometheus_client import REGISTRY

import vllm.envs as envs
from vllm import SamplingParams
from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT,
ENABLE_ARTIFICIAL_PREEMPT)
@@ -24,6 +25,13 @@
"tests/basic_correctness/test_preemption.py`")


@pytest.fixture
def worker_use_ray() -> bool:
# When the SPMD worker is used, set worker_use_ray=True
# to test that the delta input optimization works with preemption.
return envs.VLLM_USE_RAY_SPMD_WORKER


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [96])
@@ -36,6 +44,7 @@ def test_chunked_prefill_recompute(
dtype: str,
max_tokens: int,
chunked_prefill_token_size: int,
worker_use_ray: bool,
) -> None:
"""Ensure that chunked prefill works with preemption."""
max_num_seqs = min(chunked_prefill_token_size, 256)
@@ -54,6 +63,7 @@ def test_chunked_prefill_recompute(
max_num_batched_tokens=max_num_batched_tokens,
enable_chunked_prefill=enable_chunked_prefill,
max_num_seqs=max_num_seqs,
worker_use_ray=worker_use_ray,
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
@@ -79,6 +89,7 @@ def test_preemption(
model: str,
dtype: str,
max_tokens: int,
worker_use_ray: bool,
) -> None:
"""By default, recompute preemption is enabled"""

@@ -89,6 +100,7 @@
model,
dtype=dtype,
disable_log_stats=False,
worker_use_ray=worker_use_ray,
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
@@ -132,6 +144,7 @@ def test_swap(
dtype: str,
max_tokens: int,
beam_width: int,
worker_use_ray: bool,
) -> None:
"""Use beam search enables swapping."""
example_prompts = example_prompts[:1]
@@ -144,6 +157,7 @@ def test_swap(
dtype=dtype,
swap_space=10,
disable_log_stats=False,
worker_use_ray=worker_use_ray,
) as vllm_model:
vllm_outputs = vllm_model.generate_beam_search(example_prompts,
beam_width, max_tokens)
@@ -188,6 +202,7 @@ def test_swap_infeasible(
dtype: str,
max_tokens: int,
beam_width: int,
worker_use_ray: bool,
) -> None:
"""Verify infeasible swap request will be ignored."""
BLOCK_SIZE = 16
@@ -204,6 +219,7 @@
# decode blocks are not enough to finish.
num_gpu_blocks_override=prefill_blocks + decode_blocks,
max_model_len=(prefill_blocks + decode_blocks) * BLOCK_SIZE,
worker_use_ray=worker_use_ray,
) as vllm_model:
sampling_params = SamplingParams(n=beam_width,
use_beam_search=True,
@@ -230,6 +246,7 @@ def test_preemption_infeasible(
model: str,
dtype: str,
max_tokens: int,
worker_use_ray: bool,
) -> None:
"""Verify infeasible preemption request will be ignored."""
BLOCK_SIZE = 16
@@ -244,6 +261,7 @@
# ignored instead of hanging forever.
num_gpu_blocks_override=prefill_blocks + decode_blocks // 2,
max_model_len=((prefill_blocks + decode_blocks // 2) * BLOCK_SIZE),
worker_use_ray=worker_use_ray,
) as vllm_model:
sampling_params = SamplingParams(max_tokens=max_tokens,
ignore_eos=True)
33 changes: 33 additions & 0 deletions tests/core/test_serialization.py
@@ -0,0 +1,33 @@
import msgspec

from vllm.executor.msgspec_utils import decode_hook, encode_hook
from vllm.sequence import ExecuteModelRequest

from ..spec_decode.utils import create_batch


def test_msgspec_serialization():
num_lookahead_slots = 4
seq_group_metadata_list, _, _ = create_batch(16, num_lookahead_slots)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=num_lookahead_slots,
running_queue_size=4)

encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
decoder = msgspec.msgpack.Decoder(ExecuteModelRequest,
dec_hook=decode_hook)
req = decoder.decode(encoder.encode(execute_model_req))
expected = execute_model_req.seq_group_metadata_list
actual = req.seq_group_metadata_list
assert (len(expected) == len(actual))
expected = expected[0]
actual = actual[0]

assert expected.block_tables == actual.block_tables
assert expected.is_prompt == actual.is_prompt
assert expected.request_id == actual.request_id
assert (expected.seq_data[0].prompt_token_ids ==
actual.seq_data[0].prompt_token_ids)
assert (expected.seq_data[0].output_token_ids ==
actual.seq_data[0].output_token_ids)
23 changes: 13 additions & 10 deletions tests/distributed/test_basic_distributed_correctness.py
@@ -22,15 +22,17 @@
@pytest.mark.skipif(cuda_device_count_stateless() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize(
"model, distributed_executor_backend, attention_backend, test_suite", [
("facebook/opt-125m", "ray", "", "L4"),
("facebook/opt-125m", "mp", "", "L4"),
("meta-llama/Llama-2-7b-hf", "ray", "", "L4"),
("meta-llama/Llama-2-7b-hf", "mp", "", "L4"),
("facebook/opt-125m", "ray", "", "A100"),
("facebook/opt-125m", "mp", "", "A100"),
("facebook/opt-125m", "mp", "FLASHINFER", "A100"),
("meta-llama/Meta-Llama-3-8B", "ray", "FLASHINFER", "A100"),
"model, distributed_executor_backend, attention_backend, "
"test_suite, enable_adag", [
("facebook/opt-125m", "ray", "", "L4", False),
("facebook/opt-125m", "ray", "", "L4", True),
("facebook/opt-125m", "mp", "", "L4", False),
("meta-llama/Llama-2-7b-hf", "ray", "", "L4", False),
("meta-llama/Llama-2-7b-hf", "mp", "", "L4", False),
("facebook/opt-125m", "ray", "", "A100", False),
("facebook/opt-125m", "mp", "", "A100", False),
("facebook/opt-125m", "mp", "FLASHINFER", "A100", False),
("meta-llama/Meta-Llama-3-8B", "ray", "FLASHINFER", "A100", False),
])
@fork_new_process_for_each_test
def test_models(
@@ -41,12 +43,13 @@ def test_models(
distributed_executor_backend: str,
attention_backend: str,
test_suite: str,
enable_adag: bool,
) -> None:

if test_suite != TARGET_TEST_SUITE:
pytest.skip(f"Skip test for {test_suite}")

if model == "meta-llama/Llama-2-7b-hf" and distributed_executor_backend == "ray" and attention_backend == "" and test_suite == "L4": # noqa
if enable_adag:
# test ray adag
os.environ['VLLM_USE_RAY_SPMD_WORKER'] = "1"
os.environ['VLLM_USE_RAY_COMPILED_DAG'] = "1"
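As a rough usage sketch of what the new `enable_adag=True` cases exercise: both environment variables shown above must be set before the engine is constructed, and they only apply to the Ray backend. The model name and `tensor_parallel_size` below are placeholders, not values taken from this PR.

```python
import os

# Set before constructing the engine; both flags are read at startup.
os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",           # placeholder model
    distributed_executor_backend="ray",  # the SPMD worker requires the Ray backend
    tensor_parallel_size=2,
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```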
19 changes: 14 additions & 5 deletions tests/distributed/test_chunked_prefill_distributed.py
@@ -6,6 +6,8 @@
```
"""

import os

import pytest

from vllm.utils import cuda_device_count_stateless
@@ -16,11 +18,12 @@

@pytest.mark.skipif(cuda_device_count_stateless() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("model, distributed_executor_backend", [
("facebook/opt-125m", "ray"),
("meta-llama/Llama-2-7b-hf", "ray"),
("facebook/opt-125m", "mp"),
("meta-llama/Llama-2-7b-hf", "mp"),
@pytest.mark.parametrize("model, distributed_executor_backend, enable_adag", [
("facebook/opt-125m", "ray", False),
("facebook/opt-125m", "ray", True),
("meta-llama/Llama-2-7b-hf", "ray", False),
("facebook/opt-125m", "mp", False),
("meta-llama/Llama-2-7b-hf", "mp", False),
])
@fork_new_process_for_each_test
def test_models(
@@ -29,7 +32,13 @@ def test_models(
example_prompts,
model: str,
distributed_executor_backend: str,
enable_adag: bool,
) -> None:
if enable_adag:
assert distributed_executor_backend == "ray"
# test ray adag
os.environ['VLLM_USE_RAY_SPMD_WORKER'] = "1"
os.environ['VLLM_USE_RAY_COMPILED_DAG'] = "1"

dtype = "half"
max_tokens = 5
11 changes: 6 additions & 5 deletions tests/samplers/test_sampler.py
@@ -1,5 +1,6 @@
import itertools
import random
from array import array
from typing import Dict, List, Optional, Tuple
from unittest.mock import Mock, patch

@@ -56,7 +57,7 @@ def _do_sample(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
seq_data={0: SequenceData(array("I", [1, 2, 3]))},
Member:
I think we could also accept list input and do the conversion inside the SequenceData constructor?

I'd like to reduce unnecessary files touched by this PR.

Collaborator Author (rkooo567), Aug 15, 2024:
The issue is that msgspec doesn't let me type the input as Union[List, array] (it cannot serialize a union of two types of the same kind; another example is something like Union[OrderedDict, dict]). Technically we could just implicitly convert all lists (I think the type there is just a hint) if you prefer that approach.

sampling_params=sampling_params,
block_tables={0: [1]},
))
@@ -201,7 +202,7 @@ def create_sampling_params(min_tokens,

def create_sequence_data(num_input=3, num_generated=0):
seq_data = SequenceData(
random.choices(range(0, VOCAB_SIZE), k=num_input))
array("I", random.choices(range(0, VOCAB_SIZE), k=num_input)))
if num_generated > 0:
seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE),
k=num_generated)
@@ -504,7 +505,7 @@ def test_sampler_mixed(seed: int, device: str):
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
seq_data={0: SequenceData(array("I", [1, 2, 3]))},
sampling_params=sampling_params,
block_tables={0: [1]},
))
@@ -600,7 +601,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
seq_data={0: SequenceData(array("I", [1, 2, 3]))},
sampling_params=SamplingParams(
temperature=1,
top_k=top_k,
@@ -650,7 +651,7 @@ def test_sampling_params(sampling_params: List[SamplingParams]):
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
seq_data={0: SequenceData(array("I", [1, 2, 3]))},
sampling_params=sampling_params[i],
block_tables={0: [1]},
))
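The review thread above leaves open whether call sites should keep constructing array("I", ...) explicitly (as this diff does) or whether SequenceData should accept plain lists and convert internally, since msgspec cannot express Union[List[int], array] in the field annotation. A hypothetical sketch of that implicit-conversion variant (class and attribute names are simplified stand-ins, not the code in this PR):

```python
from array import array
from typing import Iterable


class SequenceDataSketch:
    """Hypothetical stand-in for SequenceData that normalizes its input."""

    def __init__(self, prompt_token_ids: Iterable[int]) -> None:
        # Keep the msgspec-friendly type on the stored field only; accept any
        # iterable of ints at the call site and convert here.
        if isinstance(prompt_token_ids, array):
            self._prompt_token_ids = prompt_token_ids
        else:
            self._prompt_token_ids = array("I", prompt_token_ids)

    @property
    def prompt_token_ids(self) -> array:
        return self._prompt_token_ids


# Call sites could then keep passing plain lists:
assert SequenceDataSketch([1, 2, 3]).prompt_token_ids == array("I", [1, 2, 3])
```

Either way, only the stored field needs the msgspec-friendly array type; the test churn in this diff comes from pushing the conversion out to the call sites instead.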
5 changes: 3 additions & 2 deletions tests/spec_decode/utils.py
@@ -1,3 +1,4 @@
from array import array
from itertools import count
from typing import Callable, Dict, List, Optional
from typing import Sequence as GenericSequence
@@ -138,8 +139,8 @@ def create_seq_group_metadata_from_prompts(
seq_data={
i:
SequenceData(
prompt_token_ids=prompt_token_ids[:],
output_token_ids=cont_token_ids[:],
array("I", prompt_token_ids[:]),
_output_token_ids=array("I", cont_token_ids[:]),
),
},
sampling_params=SamplingParams(temperature=0.0, ),
3 changes: 2 additions & 1 deletion tests/test_logits_processor.py
@@ -1,4 +1,5 @@
import random
from array import array
from typing import Tuple
from unittest.mock import patch

@@ -69,7 +70,7 @@ def pick_ith(token_ids, logits):
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
seq_data={0: SequenceData(array("I", [1, 2, 3]))},
sampling_params=SamplingParams(temperature=0,
logits_processors=[pick_ith]),
block_tables={0: [1]},
4 changes: 3 additions & 1 deletion tests/test_sequence.py
@@ -1,3 +1,5 @@
from array import array

import pytest

from vllm.sequence import (CompletionSequenceGroupOutput, SamplerOutput,
@@ -54,7 +56,7 @@ def test_sampler_output_eq(sample_outputs):


def test_sequence_data_prefill():
seq_data = SequenceData(prompt_token_ids=[1, 2, 3, 4])
seq_data = SequenceData(array("I", [1, 2, 3, 4]))
assert seq_data.get_num_uncomputed_tokens() == 4
assert seq_data.get_num_computed_tokens() == 0
# advance by 2
9 changes: 5 additions & 4 deletions tests/worker/test_encoder_decoder_model_runner.py
@@ -1,3 +1,4 @@
from array import array
from typing import List

import pytest
@@ -125,10 +126,10 @@ def test_prepare_prompt(
# make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len)
seq_data = SequenceData(list(range(seq_len)))
seq_data = SequenceData(array("I", range(seq_len)))
encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
encoder_seq_lens.append(encoder_seq_len)
encoder_seq_data = SequenceData(list(range(encoder_seq_len)))
encoder_seq_data = SequenceData(array("I", range(encoder_seq_len)))
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
@@ -319,10 +320,10 @@ def test_prepare_decode(
# make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len)
seq_data = SequenceData(list(range(seq_len)))
seq_data = SequenceData(array("I", (range(seq_len))))
encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
encoder_seq_lens.append(encoder_seq_len)
encoder_seq_data = SequenceData(list(range(encoder_seq_len)))
encoder_seq_data = SequenceData(array("I", (range(encoder_seq_len))))
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=False,
9 changes: 5 additions & 4 deletions tests/worker/test_model_runner.py
@@ -1,3 +1,4 @@
from array import array
from typing import List

import pytest
@@ -46,7 +47,7 @@ def test_prepare_prompt(batch_size):
# make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len)
seq_data = SequenceData(list(range(seq_len)))
seq_data = SequenceData(array("I", range(seq_len)))
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
@@ -163,7 +164,7 @@ def test_prepare_decode_cuda_graph(batch_size):
# make sure all tokens fit into one block
context_len = i % (model_runner.block_size - 1) + 1
context_lens.append(context_len)
seq_data = SequenceData(list(range(context_len)))
seq_data = SequenceData(array("I", range(context_len)))
seq_data.update_num_computed_tokens(context_len)
# Append one token ID since prefill is finished.
seq_data.append_token_id(1, 0)
@@ -324,7 +325,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
# make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len)
seq_data = SequenceData(list(range(seq_len)))
seq_data = SequenceData(array("I", range(seq_len)))
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
@@ -340,7 +341,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
for i in range(prefill_batch_size, batch_size):
# make sure all tokens fit into one block
context_len = i % (model_runner.block_size - 1) + 1
prompt_toks = list(range(context_len))
prompt_toks = array("I", range(context_len))
seq_data = SequenceData(prompt_toks)
seq_data.append_token_id(1, 0)
seq_data.update_num_computed_tokens(context_len)