-
-
Notifications
You must be signed in to change notification settings - Fork 5.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Model] Pipeline Parallel Support for DeepSeek v2 #6519
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Full CI run is still required to merge this PR so once the PR is ready to go, please make sure to run it. If you need all test signals in between PR commits, you can trigger full CI as well. To run full CI, you can do one of these:
🚀 |
can you test the correctness locally, using https://github.com/vllm-project/vllm/blob/main/tests/distributed/test_pipeline_parallel.py ? |
Sure. I edited the file to set the model to
|
568c1d9
to
2522798
Compare
Rebased to resolve conflict from |
self.start_layer, self.end_layer, self.layers = make_layers( | ||
config.num_hidden_layers, | ||
# layer_idx is still an argument | ||
functools.partial(DeepseekV2DecoderLayer, | ||
config, | ||
cache_config=cache_config, | ||
quant_config=quant_config), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this lambda function will have prefix=
shortly after #6515 .
2522798
to
c83350f
Compare
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
c83350f
to
f22cb28
Compare
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks for addressing my comments! please test the correctness locally.
I ran the updated
|
Thanks, that might be caused by the flakiness of pp tests. I'll merge as this PR looks good to me now. Thanks for your contribution! |
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com> Signed-off-by: Alvant <alvasian@yandex.ru>
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Adds pipeline parallel support for DeepSeek v2.
Tested with https://huggingface.co/deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct using
--tensor-parallel-size 1 --pipeline-parallel-size 2