
[torch.compile] initial integration #8949

Closed · wants to merge 44 commits

Conversation

@youkaichao (Member) commented Sep 29, 2024

TODOs (can be future PRs):

  • support embedding model, encoder-decoder model, multi-modality model
  • support attention backend other than flash attention
  • support models other than llama
  • support TP
  • support PP
  • test and integrate lora and quantization
  • perf testing
  • profile and investigate compilation time reduction


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which executes a small, essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add the ready label to the PR
  • Enable auto-merge.

🚀

@youkaichao (Member, Author) commented Sep 29, 2024

A simple test on H100:

Throughput:

$ # main branch
$ python benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 --model meta-llama/Meta-Llama-3-8B
Throughput: 28.99 requests/s, 14843.59 tokens/s

$ # this branch
$ python benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 --model meta-llama/Meta-Llama-3-8B
Throughput: 28.89 requests/s, 14792.03 tokens/s

$ # this branch
$ VLLM_TORCH_COMPILE_LEVEL=2 python benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 --model meta-llama/Meta-Llama-3-8B
Throughput: 29.90 requests/s, 15309.14 tokens/s

That's about a 3.5% throughput improvement over the uncompiled run on this branch.
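
A quick sanity check of that figure from the two runs on this branch (plain Python, not part of the PR):

```python
# Recompute the quoted speedup from the request throughputs above.
baseline = 28.89  # requests/s, this branch, no compilation
compiled = 29.90  # requests/s, this branch, VLLM_TORCH_COMPILE_LEVEL=2
print(f"{(compiled / baseline - 1) * 100:.1f}% improvement")  # prints: 3.5% improvement
```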

Single-request serving, output token throughput (tok/s):

| Multi-step | Torch Compile Level 0 (no compilation) | Torch Compile Level 2 | Torch Compile Level 3 |
|---|---|---|---|
| 1 | 114.32 | 115.61 (+1.1%) | 116.92 (+2.3%) |
| 8 | 119.37 | 120.39 (+0.8%) | 122.15 (+2.3%) |
| 16 | 119.82 | N/A | N/A |

@youkaichao (Member, Author) commented Sep 30, 2024

Pipeline parallel

When I enable pipeline parallel, there's a dynamo error:

[rank0]:     var = tx.output.side_effects.track_object_new(
[rank0]:   File "/data/youkaichao/miniconda/envs/vllm/lib/python3.9/site-packages/torch/_dynamo/side_effects.py", line 243, in track_object_new
[rank0]:     obj = object_new(user_cls)
[rank0]: torch._dynamo.exc.InternalTorchDynamoError: object.__new__(IntermediateTensors) is not safe, use IntermediateTensors.__new__()

[rank0]: from user code:
[rank0]:    File "/data/youkaichao/vllm/vllm/model_executor/models/llama.py", line 450, in forward
[rank0]:     model_output = self.model(input_ids, positions, kv_caches,
[rank0]:   File "/data/youkaichao/miniconda/envs/vllm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/data/youkaichao/vllm/vllm/model_executor/models/llama.py", line 339, in forward
[rank0]:     return IntermediateTensors({

cc @anijain2305

It turns out to be caused by msgpack: the class definition

class IntermediateTensors(

gives IntermediateTensors a custom __new__ (hence the object.__new__ error above). When I change it to a normal dataclass, it works.
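
A minimal sketch of the workaround (not the PR's actual code; the real vLLM class carries more state): a plain dataclass leaves object.__new__ untouched, so Dynamo can construct and return it inside a compiled region, whereas a class with a custom __new__ from a serialization base class trips the InternalTorchDynamoError above.

```python
# Sketch only, assuming the fix is "replace the msgpack-backed class with a
# plain dataclass". A dataclass keeps object.__new__, which Dynamo can replay.
from dataclasses import dataclass, field
from typing import Dict

import torch


@dataclass
class IntermediateTensors:
    # Hypothetical simplified version of vLLM's IntermediateTensors.
    tensors: Dict[str, torch.Tensor] = field(default_factory=dict)

    def __getitem__(self, key: str) -> torch.Tensor:
        return self.tensors[key]


def stage_forward(hidden: torch.Tensor) -> IntermediateTensors:
    # Constructing and returning the object inside the traced region is
    # exactly what triggered the error for the custom-__new__ variant.
    return IntermediateTensors({"hidden_states": hidden * 2})


compiled = torch.compile(stage_forward)
print(compiled(torch.randn(4, 8))["hidden_states"].shape)  # torch.Size([4, 8])
```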

Tensor parallel

When I enable tensor parallel, it runs but the output is wrong. I'm still investigating.

anijain2305 added commits to pytorch/pytorch referencing this pull request (Sep 30 and Oct 1, 2024): "…taclass has untouched __new__" ("Seen in vllm-project/vllm#8949"; Pull Request resolved: #137044).
@youkaichao (Member, Author)

Closing as this has been moved to #9058.

@youkaichao closed this Oct 3, 2024