Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Optimize SPMD architecture with delta + serialization optimization #7109

Merged
merged 48 commits into from
Aug 19, 2024

Conversation

rkooo567
Copy link
Collaborator

@rkooo567 rkooo567 commented Aug 3, 2024

This PR is a part of performance roadmap #6801 and based on #6556.

Also it exactly implements #6241

I verified the PR works with

  • chunked prefill
  • preemption
  • PP + TP + ray aDAG
  • seed (we moved this to worker already thanksfully!)

It doesn't currently work with mp yet.

More details about the PR;

  • Send delta input for SPMD. Only SequenceData is a state, so we find delta from SequenceData.
  • Optimize serialization performance using msgspec. I found pickle is at least 2X faster than ray's default serialization, and msgspec is at least 2X faster than pickle.
  • Integration to PP. We also send delta to next stage workers other than the entire SequenceGroupMetadata.

The following items will go to follow-up PRs

  • [WIP] Spec decoding + SPMD
  • Skip prepare_inputs for PP stage > 0 (blocked due to aDAG issue)
  • Lora tests
  • guided decoding

Performance benchmark

  • For realistic online serving workload, I found performance match for high TP (e.g., llama 70B + tp 4). I will post the result soon.
  • For PP, we found it has much higher throughput

We are still running more benchmark.

  • [WIP] We are benchmark PP8 + a10g + throughput benchmark.
  • [WIP] We are going to benchmark 405B + PP 2 + TP 8 on 2 A100 nodes.

TP + small model + throughput benchmark

max_num_seqs=256, opt 125m, tp 2
This is the setup control plane overhead is most amplified. We found there's only 5% regression in this setup.

# after
Throughput: 20.12 requests/s, 23181.18 tokens/s
# before
Throughput: 21.43 requests/s, 24684.19 tokens/s

max_num_seqs=64, opt 125m, tp 2
We found at smaller batch size (which is more common for "real workloads"), SPMD is faster

# after
Throughput: 14.68 requests/s, 28297.32 tokens/s
# before
Throughput: 13.87 requests/s, 26739.96 tokens/s

PP throughput benchmark

We also ran PP benchmark with QPS 10, 1000 requests, PP4 + a10g + llama 3 8B + shareGPT dataset. We enabled chunked prefill to reduce bubbles (it uses OpenAI server). This was before including zeromq + decoupling OpenAI server from engine, so with that, the perf impact may be not as good as this result.

Default

┏━━━━━┳━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┓
┃ Run ┃ qps ┃ server_name ┃ e2e_time_s ┃ error_rate ┃ e2e_throughput_request_per_s ┃ itl_ms_mean ┃ ttft_ms_mean ┃ request_e2e_s_mean ┃
┡━━━━━╇━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━┩
│   0 │ 10  │ local_vllm  │ 38.7       │ 0          │ 2.6                          │ 207.3       │ 422.3        │ 26.8               │
└─────┴─────┴─────────────┴────────────┴────────────┴──────────────────────────────┴─────────────┴──────────────┴────────────────────┘

SPMD

┏━━━━━┳━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┓
┃ Run ┃ qps ┃ server_name ┃ e2e_time_s ┃ error_rate ┃ e2e_throughput_request_per_s ┃ itl_ms_mean ┃ ttft_ms_mean ┃ request_e2e_s_mean ┃
┡━━━━━╇━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━┩
│   0 │ 10  │ local_vllm  │ 30.2       │ 0          │ 3.3                          │ 155.4       │ 312.5        │ 19.1               │
└─────┴─────┴─────────────┴────────────┴────────────┴──────────────────────────────┴─────────────┴──────────────┴────────────────────┘

So we have about 30% throughput improvement (when chunked prefill is disabled, the diff was bigger)

PP Batch inference benchmark

Generally 10~15% throughput improvement for batch inference workloads (this doesn't use OpenAI server) with PP8 + L4 GPUs. It is the batch inference workloads we care at Anyscale.

L4 + PP8 + chunked
Throughput: 562.3082515078258 tokens/s
L4 + PP8 + not chunked
Throughput: 77.0155615308245 tokens/s
L4 + PP8 + chunked + SPMD + no nccl
Throughput: 610.663914303865 tokens/s
L4 + PP8 + chunked + SPMD + nccl
Throughput: 608.6000705018308 tokens/s

TP latency benchmark

WIP... (make sure it matches)

TP throughput benchmark

WIP... (make sure it matches)


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

@rkooo567 rkooo567 changed the title SPMD + PP [WIP] SPMD + PP Aug 3, 2024
Copy link

github-actions bot commented Aug 3, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which consists a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of default ones by unblocking the steps in your fast-check build on Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

vllm/worker/worker.py Outdated Show resolved Hide resolved
@rkooo567 rkooo567 changed the title [WIP] SPMD + PP [Core] Optimize SPMD architecture with delta + serialization optimization Aug 6, 2024
@rkooo567
Copy link
Collaborator Author

rkooo567 commented Aug 6, 2024

cc @ruisearch42

@rkooo567
Copy link
Collaborator Author

rkooo567 commented Aug 6, 2024

Please provide a high level feedback! There's a little more work I need (will be done tmrw)

  • Add unit tests for sequence.py + remove request_ids
  • fix failed CI tests
  • Resolve merge conflict

tests/distributed/test_pipeline_parallel.py Outdated Show resolved Hide resolved
tests/prompts/example.txt Outdated Show resolved Hide resolved
vllm/inputs/registry.py Outdated Show resolved Hide resolved
@@ -18,16 +18,17 @@ class LoRARequest(AdapterRequest):
lora_int_id must be globally unique for a given adapter.
This is currently not enforced in vLLM.
"""
__metaclass__ = AdapterRequest
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is required because LoraRequest needs to inherit from this ABC class


import torch
from PIL import Image
from transformers import PreTrainedTokenizerBase

from vllm.config import ModelConfig
# from vllm.config import ModelConfig
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found there's circular dependency issue with this. I will either fix it or revert it

vllm/worker/worker.py Outdated Show resolved Hide resolved
@youkaichao
Copy link
Member

Since this feature is mainly used for ray spmd path, I'm okay with it if it cannot deal with logit processors currently, and we will refactor logits processor anyway in the future.

My suggestion is to have a ARRAY_INT32 constant somewhere, and switches between I or L depending on the actual size.

I think this part should be addressed. We should not have a machine-dependent code here. It would be difficult to modify in the future.

We can use some constant like VLLM_TOKEN_ID_ARRAY_TYPE , in case we need to change it to some other types in the future.

@rkooo567
Copy link
Collaborator Author

rkooo567 commented Aug 17, 2024

We can use some constant like VLLM_TOKEN_ID_ARRAY_TYPE , in case we need to change it to some other types in the future.

yeah actually a good point. working on it now (I added typecode == 'l' in sequenceData, but I agree we should do this)

@youkaichao
Copy link
Member

I'd like to see more perf number comparison here, what's the benefit with this pr in spmd case, and what's the gap (if any) with the main branch (non-spmd case).

@rkooo567
Copy link
Collaborator Author

rkooo567 commented Aug 17, 2024

I'd like to see more perf number comparison here, what's the benefit with this pr in spmd case, and what's the gap (if any) with the main branch (non-spmd case).

Yep! I am preparing a doc actually.

https://docs.google.com/document/d/1XSKPna9-seHrYf1Vwaw-puKWW15avk9Myn8yzWbSQ9Q/edit#heading=h.jtbcvct9biud

It is not fully done yet (still need result from spec decoding, latency benchmark, and 2 nodes PP perf). But there are some numbers here already. I am trying to make the full result available by early next week

  • benefits: spec decoding / pp perf improvement. It is because now model is not interrupted by an engine (and spec decoding & pp has more work than regular tp). for spec decoding, spmd also removes the communciation overhead because we don't need to broadcast tokens
  • It is easier to make output processor async after this
  • have the same impact as decoupling server/engine using zeromq. This is especially useful for company like Anyscale who doesn't use openai server directly (or batch inference kind of use case)
  • con: at very large batch + small model + high tp, it seems to have 5~ish% regression in throughput (I believe there's still space for optimization). Latency seems match

@rkooo567
Copy link
Collaborator Author

rkooo567 commented Aug 17, 2024

VLLM_TOKEN_ID_ARRAY_TYPE

addressed

@rkooo567
Copy link
Collaborator Author

I will merge it tonight unless there are more comments!

@njhill
Copy link
Member

njhill commented Aug 18, 2024

@rkooo567 maybe you could wait until tomorrow, I will try to review tonight (sorry it has been on my todo list!)

@rkooo567
Copy link
Collaborator Author

rkooo567 commented Aug 18, 2024

@njhill is it okay if I follow up after merging the PR (if it is mostly nit comments, I will just address by tomorrow and merge it)? It's been almost 3 weeks since I made a PR... and it affects our internal timeline. (also merge conflict is very painful since the PR touches sequence.py)


VLLM_TOKEN_ID_ARRAY_TYPE = "l"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

which type do you intend to use? 64 bit or 32 bit?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Screenshot 2024-08-17 at 9 51 55 PM

signed long type (which was originally used for array).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean, you should use a machine independent type, because pytorch explicitly use 32bit or 64bit. If the type change in some machines, sometimes there will be silent error. cc @comaniac , Cody met this before.

Copy link
Collaborator Author

@rkooo567 rkooo567 Aug 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually this array is not passed into the torch tensor (you can access it only when you output_tokens_array or prompt_tokens_array API), and it is basically the exacatly same behavior (we originally used "l" type array for sequence data https://github.com/vllm-project/vllm/blob/ce143353c622318a9abf113bebee1cfebc274e0f/vllm/sequence.py#L134C14-L134C31) as before this PR. So I think we don't need to worry about "pytorch explicitly use 32bit or 64bit."

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

besides for the answer of your question, array("l") is signed long (4 bytes), and it is same for torch I believe.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you intend to use 32bit int, you should try both i and l , and see which one is 32 bit (4 bytes).

I just want to make sure, we always use the same bit width integer, no matter which platform we are running on.

Copy link
Collaborator Author

@rkooo567 rkooo567 Aug 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I read their doc a little more, and it looks like Python array doesn't support platform independent type.

The actual representation of values is determined by the machine architecture (strictly speaking, by the C implementation). The actual size can be accessed through the array.itemsize attribute.

In my platform (x86_64), "I" (unsigned) is 4 bytes, and "l" (signed long) is 8 bytes. So, I think "l" (signed long) is equivalent to torch.long.

https://anyscaleteam.slack.com/archives/C06D3FAT2RM/p1722552126851359?thread_ts=1722551878.077229&cid=C06D3FAT2RM

And according to Simon, the token id is unsigned int. So as long as we don't pass token ids directly to torch tensor (which we don't), using "I" (unsigned int) is sufficient. But using other type such as signed long has no problem I believe.

Regardless, I feel like it is not something we should address in this PR because I just kept it to
"signed long", which is same as before. I think if we want to be really safe, we can either use numpy or write code to decide type based on platforms in a follow up.

Copy link
Collaborator Author

@rkooo567 rkooo567 Aug 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

besides, what was the example that has silently failed? I am trying unsigned int array -> torch.long type tensor, and it seems to work at least (or give a proper error);

In [9]: b
Out[9]: array('I', [1])

In [10]: b = array("I", [4294967295])

In [11]: b = array("I", [4294967295 + 1])
---------------------------------------------------------------------------
OverflowError                             Traceback (most recent call last)
Cell In[11], line 1
----> 1 b = array("I", [4294967295 + 1])

OverflowError: unsigned int is greater than maximum

In [12]: b = array("I", [4294967295])

In [13]: torch.tensor(b, dtype=torch.long)
Out[13]: tensor([4294967295])

@njhill
Copy link
Member

njhill commented Aug 18, 2024

@rkooo567 was hoping it should be ok to wait until Sunday rather than Saturday night :)
I was now planning to look early in the morning..

My main question is whether it’s still optional/opt-in and whether it has any negative perf impact at all on any cases where spmd is not enabled?

Where we are making multiple different kinds of optimizations I wonder if it’s better for them to be in separate PRs so that the benefits can be assessed individually. Like the changes here for msgspec, deltas, use of python arrays, etc. It would probably then also be easier to get them merged quicker and have bit less merge conflict pain etc.

@rkooo567
Copy link
Collaborator Author

rkooo567 commented Aug 18, 2024

My main question is whether it’s still optional/opt-in and whether it has any negative perf impact at all on any cases where spmd is not enabled?

All the delta input & spmd related code is opt-in, and completely separated from original code path.

There's no other change other than delta + serialization. array related changes are same as before (just a bit of refactoring for msgspec). Serialization optimization doesn't affect the main code path because it is not serialized for regular path.

Only con I can think of is that now all sequence group metadata inputs are using msgspec.Struct instead of dataclass. Both of them has almost the same semantics, but msgspec.Struct is a little more restrictive (e.g., you cannot have union of the same type). I think performance-wise, I actually msgspec.Struct may be faster because it is a c object technically. For example, I found removing this reallocation optimization has no performance impact after this PR. #7109 (comment)

@rkooo567
Copy link
Collaborator Author

also benchmark result from master; https://buildkite.com/vllm/ci-aws/builds/7074#01916209-fef7-45e9-9b63-34641474c4ea

this PR: https://buildkite.com/vllm/ci-aws/builds/7072

so I think there's almost no change

Copy link
Member

@njhill njhill left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @rkooo567 ... my comments are mostly nits.

Some more higher level thoughts (not suggesting to change in this PR):

  • In general it feels like data structures used within the scheduler should be decoupled from those used for internal RPCs
  • I guess none of this would be necessary if we do more complete spmd as recently discussed, where the scheduler is deterministic and also included
  • I think as implemented here it shouldn't be hard to exploit for non-spmd cases too - such as non-spmd PP (e.g. the MP impl) and more generally with a persistent batch (see below)
  • Rather than for every step constructing full metadata, computing/sending deltas, reconstructing full metadata from deltas, using this to create full tensors on the GPU, I feel we should aim to keep things stateful further down - the GPU tensors can be persistent/pre-allocated - and then just operate in deltas everywhere

vllm/executor/msgspec_utils.py Outdated Show resolved Hide resolved
vllm/executor/msgspec_utils.py Outdated Show resolved Hide resolved
if not self.ignore_eos:
eos_ids.update(self.stop_token_ids)
self.stop_token_ids = list(eos_ids)

@cached_property
@property
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why will the sampling type change on the fly now?

If it's a matter of cached_property not being supported with msgspec, maybe we could add a _sampling_type field, set in post-init?

vllm/sequence.py Outdated Show resolved Hide resolved
@rkooo567
Copy link
Collaborator Author

rkooo567 commented Aug 18, 2024

@njhill thanks for the quick review! addressing comments rn!

Some more higher level thoughts (not suggesting to change in this PR):

Yeah I generally agreed with the idea. I think once the basic spmd version is working, it should be straightforward to make it fulll spmd. I think make batching stateful within workers are the right idea.

One diff in my mind is that imo it is still more beneficial to decouple scheduler from workers for more complicated scheduling requirement in the future (starting from pipeline parallelism). And I think the data that needs to be exchanged from scheduler <> worker is pretty small if the batches become stateful. And also it should be totally possible to have low overhead for scheduler <> worker (I think it is already less than 300us for most of cases. 0 if we can make scheduler async).

@rkooo567
Copy link
Collaborator Author

comments all addressed @njhill !

@rkooo567
Copy link
Collaborator Author

okay, I am merging the PR. I will follow up new comments in a follow up PR

@rkooo567 rkooo567 merged commit ff7ec82 into vllm-project:main Aug 19, 2024
64 checks passed
zifeitong pushed a commit to zifeitong/vllm that referenced this pull request Aug 20, 2024
fialhocoelho pushed a commit to opendatahub-io/vllm that referenced this pull request Aug 22, 2024
omrishiv pushed a commit to omrishiv/vllm that referenced this pull request Aug 26, 2024
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants