
[WIP][Core] Support tensor parallel division with remainder of attention heads #5367

Closed
wants to merge 24 commits

Conversation

NadavShmayo
Contributor

Solves #1041 and #5003

This pull request adds support for tensor parallel division with a remainder (currently only for some models; other models could easily be changed to support it as well), meaning you can divide any number of attention heads across any number of GPUs in a tensor parallel manner.

For example, I have a use case where I want to deploy a 70B model that doesn't quite fit on 2x A100-80GiB GPUs, so I want to deploy it using 3x A100-80GiB GPUs, but this currently isn't supported by vLLM.

Weight sharding

I implemented this by loading the remainder of the weights onto the first GPUs. For example, when sharding a weight matrix of size 4096 across 3 GPUs, we shard it as 1366, 1365, 1365.
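
For reference, here is a minimal sketch of the shard-size arithmetic described above (function names are illustrative and not necessarily the exact helpers added in this PR):

```python
def uneven_shard_size(total_size: int, tp_rank: int, tp_size: int) -> int:
    # The first `total_size % tp_size` ranks each take one extra row/column.
    return total_size // tp_size + (1 if tp_rank < total_size % tp_size else 0)


def uneven_shard_offset(total_size: int, tp_rank: int, tp_size: int) -> int:
    # Starting index of this rank's shard, accounting for the larger shards
    # held by the lower ranks.
    base, remainder = divmod(total_size, tp_size)
    return tp_rank * base + min(tp_rank, remainder)


# Sharding a 4096-wide dimension across 3 GPUs gives 1366, 1365, 1365.
assert [uneven_shard_size(4096, r, 3) for r in range(3)] == [1366, 1365, 1365]
```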

Grouped Query Attention sharding

For attention weights with GQA (Grouped-Query-Attention) this is a bit more complex, as we want each group of attention heads to be loaded on the same GPU and include both the query matrices and key-value matrices. In this case we shard in the same manner as before, but ensure the shard size is a multiple of the attention heads per group:
32 Attention heads with 8 Key-Value heads -> (12 Attention heads with 3 Key-Value Heads, 12 Attention heads with 3 Key-Value Heads, 8 Attention heads with 2 Key-Value Heads).
Perhaps this logic could be optimized to avoid some GPUs loading far fewer heads than others.
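
To make the grouping concrete, here is a hedged sketch of the split described above (the helper name is assumed and this is not necessarily the PR's exact code):

```python
def gqa_head_split(num_heads: int, num_kv_heads: int, tp_size: int):
    # Query heads per KV head (the GQA group size); 32 / 8 = 4 in the example.
    group_size = num_heads // num_kv_heads
    # Spread the KV heads unevenly, with the first ranks absorbing the remainder.
    kv_per_rank = [
        num_kv_heads // tp_size + (1 if r < num_kv_heads % tp_size else 0)
        for r in range(tp_size)
    ]
    # Each rank's query-head count stays a multiple of the group size.
    q_per_rank = [kv * group_size for kv in kv_per_rank]
    return q_per_rank, kv_per_rank


# 32 attention heads with 8 KV heads on 3 GPUs -> ([12, 12, 8], [3, 3, 2]).
print(gqa_head_split(32, 8, 3))
```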

Further steps

  1. I removed some assert clauses for the case where the tensor parallel world size is greater than the number of Key-Value heads; I still have to make sure this case works correctly.
  2. Add support for all models.
  3. Validate unit tests and possibly add some.
  4. Implement advanced GQA sharding logic.

@itaybar

itaybar commented Jun 9, 2024

Looks amazing

@AlmogBaku

Some builds are failing?...

@mgoin
Collaborator

mgoin commented Jun 10, 2024

@NadavShmayo please run ./format.sh to enable the CI to run.

@NadavShmayo
Contributor Author

NadavShmayo commented Jun 11, 2024

@NadavShmayo please run ./format.sh to enable the CI to run.

I ran the formatting script before pushing and fixed all the reported problems; is there something I'm missing here?

EDIT: There seems to be a problem with either the formatting script or the CI, as the formatting script formats the imports in the model_executor/layers/linear.py file differently than the CI expects, and when I manually format them the way the CI expects, the CI fails at another stage.
After further investigation I found that ruff and yapf seem to conflict for some reason, which means one of them will fail during CI. I'm still not sure how to fix it.

I've also fixed some of the failing tests for now, and will hopefully fix the rest later today.

@NadavShmayo
Contributor Author

@simon-mo @mgoin
Hey, is there any chance we can merge this into the next release?
I've just pushed a version that should pass all of the tests. I've added comments to one problematic import block where yapf and isort conflict, and disabled isort for that block, at least for now.

I'll also start working on a separate pull request to improve the performance of this implementation.

Collaborator

@mgoin mgoin left a comment

Could you add a TP=3 case to a current distributed test so we can have this actively tested?
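
For example, a TP=3 case could be added along these lines (a hedged sketch based on the usual vllm_runner test fixture; the model choice and exact parametrization are assumptions, not a prescription):

```python
import pytest


@pytest.mark.parametrize("tp_size", [2, 3, 4])
def test_uneven_tensor_parallel(vllm_runner, tp_size):
    # TP=3 exercises the uneven head split; 2 and 4 remain the even baseline.
    with vllm_runner("meta-llama/Meta-Llama-3-8B",
                     tensor_parallel_size=tp_size) as vllm_model:
        outputs = vllm_model.generate_greedy(["Hello, my name is"],
                                             max_tokens=16)
    assert len(outputs) == 1
```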

 tp_size = get_tensor_model_parallel_world_size()
 assert self.quant_method is not None
-self.output_size_per_partition = divide(self.output_size, tp_size)
+self.output_size_per_partition = output_size // tp_size + (
+    output_size % tp_size > tp_rank)
Collaborator

Could you separate this out into its own variable for the remainder, for clarity, and/or add a comment describing the intended behavior? The condition makes it a bit unclear.

Contributor Author

Nice catch; this should use the util function I added for this logic, which should make it more readable.
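
For illustration, such a util function could look something like this (name and signature assumed, mirroring the shard-size sketch in the PR description):

```python
def divide_with_remainder(numerator: int, tp_rank: int, tp_size: int) -> int:
    # Same arithmetic as the inline expression above, with the remainder
    # pulled into its own variable for readability.
    quotient, remainder = divmod(numerator, tp_size)
    return quotient + (1 if tp_rank < remainder else 0)


# e.g. self.output_size_per_partition = divide_with_remainder(
#          output_size, tp_rank, tp_size)
```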

@@ -49,7 +51,7 @@
 from vllm.sequence import IntermediateTensors, SamplerOutput


-@torch.compile
+#@torch.compile
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please restore

@mgoin
Collaborator

mgoin commented Jul 3, 2024

FYI I think the current complexity in this PR of needing to edit each model definition makes this difficult to accept and commit to. We will need careful consideration to commit to supporting this uneven TP scenario, so I believe getting this into the next release might be optimistic.

@NadavShmayo
Contributor Author

NadavShmayo commented Jul 3, 2024

FYI I think the current complexity in this PR of needing to edit each model definition makes this difficult to accept and commit to. We will need careful consideration to commit to supporting this uneven TP scenario, so I believe getting this into the next release might be optimistic.

@mgoin I've fixed the code according to your comments, and added some test cases for this PR, thanks!
(Feel free to ask me to add more test cases in case I missed something).

Regarding your last comment, I understand it does seem like a major change, but wouldn't it be possible to do this gradually?
Currently I've implemented this logic for 3 model architectures, but for the other architectures there is an assert statement in the code that won't allow uneven tensor parallelism.

Anyway, for most models the changes required to support this should be minor.

@njhill
Member

njhill commented Jul 4, 2024

Thanks @NadavShmayo, I think there may be quite a bit of interest in this being supported. I agree with @mgoin, though, that we'd want to minimize the invasiveness of the changes as much as possible, and this is unlikely to make it into the imminent release.

@NadavShmayo
Contributor Author

@njhill
I understand and agree we should be careful when merging this.

Would it be better if I added a list of models that support this feature, plus a validation step during setup to make sure the model supports the feature when it's used? This way we can avoid changing all the model architectures at once.

Anyway, I'll make sure to be attentive to comments on this PR so it can be merged as soon as possible, while being careful not to impair existing vLLM functionality.

@haltingstate

haltingstate commented Jul 12, 2024

Frontier models such as Qwen2 cannot be run until this is merged.

Many frontier agent models use prime numbers for the attention head count, so the attention head count or vocabulary size is not divisible by 2, 4, or 8, and other frontier models divide evenly only across 1 or 20 GPUs.

Frontier sub-7B models are also affected by this.

I think this change should be merged soon, or it will delay testing and evaluation for frontier agent model research.

Since not all models can currently be loaded by vLLM, uniform evaluation and performance benchmarking of frontier agent models is impossible.

@robertgshaw2-neuralmagic
Collaborator

Frontier models such as Qwen2 cannot be run until this is merged.

Many frontier agent models use prime numbers for the attention head count, so the attention head count or vocabulary size is not divisible by 2, 4, or 8, and other frontier models divide evenly only across 1 or 20 GPUs.

Frontier sub-7B models are also affected by this.

I think this change should be merged soon, or it will delay testing and evaluation for frontier agent model research.

Since not all models can currently be loaded by vLLM, uniform evaluation and performance benchmarking of frontier agent models is impossible.

Which Qwen2 models are not able to be run?

@youkaichao
Member

Given that this change is too intrusive, I have noted in the doc that we should use pipeline parallelism in this case.

@youkaichao youkaichao closed this Jul 15, 2024
@NadavShmayo
Contributor Author

NadavShmayo commented Jul 17, 2024

@youkaichao I do understand the concerns about how intrusive this change is.
However, in most use cases tensor parallelism seems to give better performance than pipeline parallelism, especially in terms of single-request latency (with tensor parallelism I could get 24 tokens/s, while with pipeline parallelism I couldn't get more than 11 tokens/s), and I think it would be a shame to give up on that.
Here are some benchmarks I ran comparing the two approaches for distributing LLaMA 3 70B across 3 GPUs:

Pipeline parallel performance

Results of Pipeline Parallel = 3, request rate inf, 100 requests
============ Serving Benchmark Result ============
Successful requests:                     100
Benchmark duration (s):                  95.01
Total input tokens:                      22398
Total generated tokens:                  17545
Request throughput (req/s):              1.05
Input token throughput (tok/s):          235.75
Output token throughput (tok/s):         184.67
---------------Time to First Token----------------
Mean TTFT (ms):                          4693.10
Median TTFT (ms):                        5123.46
P99 TTFT (ms):                           8589.95
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          206.18
Median TPOT (ms):                        133.36
P99 TPOT (ms):                           1061.37
---------------Inter-token Latency----------------
Mean ITL (ms):                           153.84
Median ITL (ms):                         116.85
P99 ITL (ms):                            431.37
==================================================

Results of Pipeline Parallel = 3, request rate inf, 1 request
============ Serving Benchmark Result ============
Successful requests:                     1
Benchmark duration (s):                  11.47
Total input tokens:                      13
Total generated tokens:                  120
Request throughput (req/s):              0.09
Input token throughput (tok/s):          1.13
Output token throughput (tok/s):         10.46
---------------Time to First Token----------------
Mean TTFT (ms):                          149.52
Median TTFT (ms):                        149.52
P99 TTFT (ms):                           149.52
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          95.10
Median TPOT (ms):                        95.10
P99 TPOT (ms):                           95.10
---------------Inter-token Latency----------------
Mean ITL (ms):                           95.55
Median ITL (ms):                         94.28
P99 ITL (ms):                            109.13
==================================================

Results of Pipeline Parallel = 3, request rate 10, 100 requests
============ Serving Benchmark Result ============
Successful requests:                     100
Benchmark duration (s):                  95.62
Total input tokens:                      22398
Total generated tokens:                  17059
Request throughput (req/s):              1.05
Input token throughput (tok/s):          234.25
Output token throughput (tok/s):         178.41
---------------Time to First Token----------------
Mean TTFT (ms):                          891.28
Median TTFT (ms):                        811.73
P99 TTFT (ms):                           2292.04
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          282.53
Median TPOT (ms):                        141.45
P99 TPOT (ms):                           2035.87
---------------Inter-token Latency----------------
Mean ITL (ms):                           142.17
Median ITL (ms):                         117.60
P99 ITL (ms):                            180.30
==================================================

Tensor parallel performance

Results of Tensor Parallel = 3, request rate inf, 100 requests
============ Serving Benchmark Result ============
Successful requests:                     100
Benchmark duration (s):                  69.74
Total input tokens:                      22398
Total generated tokens:                  17028
Request throughput (req/s):              1.43
Input token throughput (tok/s):          321.14
Output token throughput (tok/s):         244.15
---------------Time to First Token----------------
Mean TTFT (ms):                          7026.26
Median TTFT (ms):                        9655.10
P99 TTFT (ms):                           13218.20
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          237.05
Median TPOT (ms):                        105.48
P99 TPOT (ms):                           1870.65
---------------Inter-token Latency----------------
Mean ITL (ms):                           145.35
Median ITL (ms):                         85.91
P99 ITL (ms):                            257.87
==================================================

Results of Tensor Parallel = 3, request rate inf, 1 request
============ Serving Benchmark Result ============
Successful requests:                     1
Benchmark duration (s):                  5.16
Total input tokens:                      13
Total generated tokens:                  120
Request throughput (req/s):              0.19
Input token throughput (tok/s):          2.52
Output token throughput (tok/s):         23.28
---------------Time to First Token----------------
Mean TTFT (ms):                          81.50
Median TTFT (ms):                        81.50
P99 TTFT (ms):                           81.50
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          42.63
Median TPOT (ms):                        42.63
P99 TPOT (ms):                           42.63
---------------Inter-token Latency----------------
Mean ITL (ms):                           42.95
Median ITL (ms):                         42.62
P99 ITL (ms):                            43.18
==================================================

Results of Tensor Parallel = 3, request rate 10, 100 requests
============ Serving Benchmark Result ============
Successful requests:                     100
Benchmark duration (s):                  70.44
Total input tokens:                      22398
Total generated tokens:                  17020
Request throughput (req/s):              1.42
Input token throughput (tok/s):          317.98
Output token throughput (tok/s):         241.63
---------------Time to First Token----------------
Mean TTFT (ms):                          2574.39
Median TTFT (ms):                        1955.73
P99 TTFT (ms):                           6755.54
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          290.98
Median TPOT (ms):                        120.91
P99 TPOT (ms):                           2697.50
---------------Inter-token Latency----------------
Mean ITL (ms):                           126.39
Median ITL (ms):                         86.07
P99 ITL (ms):                            262.36
==================================================

Alternative suggestion

I suggest the following alternative, which avoids making intrusive changes to the code while still allowing this feature to be implemented, at the cost of some code duplication:

We could add a flag to explicitly enable this feature (uneven tensor parallel distribution), and when this flag is enabled we initialize a special model class (for example, LlamaUnevenTPForCausalLM instead of LlamaForCausalLM).
We would also keep all the parallel linear layers unchanged and add new classes with the implementation from this PR; these new classes would be used only by the LlamaUnevenTPForCausalLM class.

This way we don't change any implementation of existing models or layers, while still gaining the significant performance improvements this pull request enables in this scenario. A rough sketch of this approach follows.
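
To make this more concrete, here is a hedged sketch of how such a wrapper could be registered without touching vLLM's core code (the class name comes from the suggestion above; the registration call follows vLLM's out-of-tree model registration API, and the override details are only indicated in comments):

```python
from vllm import ModelRegistry
from vllm.model_executor.models.llama import LlamaForCausalLM


class LlamaUnevenTPForCausalLM(LlamaForCausalLM):
    """Llama variant intended to swap in the uneven-TP parallel layers.

    The real implementation would build the model with the uneven-sharding
    linear/attention classes from this PR instead of the stock layers.
    """


# Registering the architecture makes it loadable like any built-in model,
# gated behind whatever flag enables uneven TP distribution.
ModelRegistry.register_model("LlamaUnevenTPForCausalLM",
                             LlamaUnevenTPForCausalLM)
```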

What do you think?

@youkaichao
Member

It is expected that TP is faster than PP if you can hold the model on a single node. This is more about maintenance burden: we have limited bandwidth and cannot maintain this feature, so we will only support common use cases in the main branch. I have seen thousands of issues with detailed environment reports, and I can confidently say that the 3-GPU use case is very rare.

My opinion is that the vLLM codebase will only support the 3-GPU case via PP. Meanwhile, you can maintain another repo to keep this feature.

In fact, your LlamaUnevenTPForCausalLM approach makes sense. Since these are quite standalone changes that do not touch vLLM's core code, it is entirely possible for you to create a wrapper package on top of vLLM to support this feature. We can then note in the documentation that anyone who wants TP with 3 GPUs should try your plugin/wrapper.

I think that's how open source works; you cannot expect every feature to be added to the main branch. Let's build the ecosystem and the community together.

@sekh77

sekh77 commented Aug 31, 2024

@youkaichao, @NadavShmayo -

I definitely have a need for this feature, and I'm pretty sure many others will also need it to be available in vLLM.

I don't see any need to use more GPUs than are necessary to load a given model. For example, if I can load a model on exactly 5 GPUs, why would I need to allocate 8 GPUs for it?

Here's my situation and requirements:

  1. I have 3 nodes in Azure with 12 A100 80GB GPUs (4 GPUs per node) connected through InfiniBand.
  2. In my conversational AI chat application, users can dynamically switch between models in the chat screen at runtime, so they can choose one model over another depending on how each performs for their complex queries.
  3. I want to pre-load my GPUs with LLaMA 3.1 70B, Mixtral 8x22B, and Databricks DBRX so that my users can choose any of these three models during chat.
  4. The application automatically calculates the model parameters based on information from the model's config.json, and then uses a formula to derive the exact number of GPUs the model will require for loading and inference.
  5. Based on this formula, LLaMA 3.1 70B requires 3 GPUs, Mixtral 8x22B requires 5 GPUs, and Databricks DBRX requires 4 GPUs.
  6. Ideally, all three models should fit on 12 GPUs. However, with the current vLLM architecture and the way it calculates the split, LLaMA 3.1 70B will need 4 GPUs (its 64 attention heads are not divisible by 3 but are divisible by 4), Mixtral 8x22B will need 8 GPUs, and DBRX will need 4 GPUs (no change for DBRX because 4 already matches vLLM's expectation).
  7. This puts me in a situation where I can only load two of the models, which take up 8 GPUs and leave the remaining 4 GPUs unused. That is not a good use of compute resources, especially given that these are expensive GPUs.

I use pipeline_parallel_size = 1 and set tensor_parallel_size to the exact number of GPUs a model needs, based on the vLLM documentation for distributed inference: https://docs.vllm.ai/en/latest/serving/distributed_serving.html

So anything that can be done to move away from the current constraints of 2, 4, 8, 16 will be highly beneficial for a lot of enterprises. This is common feedback I hear from people using vLLM. Everything else about vLLM is absolutely great and awesome, no doubt whatsoever.

@atyshka

atyshka commented Oct 31, 2024

@NadavShmayo Do you still plan to implement this feature via the plugin framework added in #7426?
