
Add Tensor Parallel to torch_native_llama #1876

Open · wants to merge 12 commits into base: main
Conversation

@kwen2501 commented Nov 2, 2024

Motivation

The torch_native_llama model currently lacks Tensor Parallel support. This PR adds it using torch.distributed APIs.

Modifications

  • Added a .tensor_parallel() utility;
  • Added ColwiseParallel and RowwiseParallel annotations to the relevant sub-modules (a minimal sketch of this annotation pattern follows below).
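
For readers unfamiliar with these APIs, here is a minimal sketch (not the PR's actual code) of how ColwiseParallel and RowwiseParallel annotations are applied via torch.distributed; the MLP module and its dimensions are illustrative assumptions:

```python
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


# Illustrative stand-in; the PR annotates the real Llama sub-modules.
class MLP(nn.Module):
    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.up_proj = nn.Linear(dim, hidden, bias=False)
        self.down_proj = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        return self.down_proj(torch.relu(self.up_proj(x)))


# Assumes torch.distributed is already initialized (e.g. via torchrun).
world_size = torch.distributed.get_world_size()
mesh = init_device_mesh("cuda", (world_size,))

# Shard the up projection column-wise and the down projection row-wise,
# so a single all-reduce per block recombines the partial results.
plan = {
    "up_proj": ColwiseParallel(),
    "down_proj": RowwiseParallel(),
}
tp_mlp = parallelize_module(MLP(4096, 11008), mesh, plan)
```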

Tests

pytest test/srt/test_torch_tp.py

Checklist

  • Format your code according to the Contributor Guide.
  • Add unit tests as outlined in the Contributor Guide.
  • Update documentation as needed, including docstrings or example tutorials.

cc: @jerryzh168 @merrymercy @wz337

@jerryzh168 (Contributor) left a comment:

LGTM. I think this is the best we can do for now, until we stop using fused QKV and rely on torch.compile for the speedup.
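
For context on the fused-QKV point: with separate q/k/v projections, each linear can be annotated independently, which is what makes a plain ColwiseParallel/RowwiseParallel plan workable; a fused qkv_proj would need a custom sharding style, since a contiguous column shard would cut across the concatenated q/k/v boundaries. A hypothetical per-projection plan (module names assumed to follow the usual Llama naming, not quoted from the PR) looks like:

```python
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

# Hypothetical plan assuming Llama-style attention module names; each
# unfused projection gets its own sharding style.
attn_plan = {
    "self_attn.q_proj": ColwiseParallel(),  # shards heads across ranks
    "self_attn.k_proj": ColwiseParallel(),
    "self_attn.v_proj": ColwiseParallel(),
    "self_attn.o_proj": RowwiseParallel(),  # all-reduce recombines outputs
}
```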

@kwen2501 marked this pull request as ready for review on November 9, 2024.
@merrymercy (Contributor) left a comment.

@merrymercy self-assigned this on Nov 9, 2024.
@kwen2501 changed the title from "[Draft] Add Tensor Parallel to torch_native_llama" to "Add Tensor Parallel to torch_native_llama" on Nov 11, 2024.
@@ -24,7 +24,10 @@
from torch import nn
A Contributor commented on this hunk:

Add a unit test, or at least some docstrings here on how to run this model?

@kwen2501 (Author) replied:

Thanks for the suggestion. I added a section describing how to run the model with TP.
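
As a rough illustration (the exact flags and model path below are assumptions, not quoted from the added section), running this model with TP through sglang looks something like:

python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-hf --tp-size 2 --json-model-override-args '{"architectures": ["TorchNativeLlamaForCausalLM"]}'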

@kwen2501 (Author) commented:

@merrymercy Thanks very much for your review. I was tied up with roadmapping, so the code change was delayed.
I added a unit test: test/srt/test_torch_tp.py.
It can be triggered by:
pytest test/srt/test_torch_tp.py
and uses two GPUs by default.
It is similar to test/srt/test_bench_latency.py.

Thanks for the rebase; I would appreciate your review!
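
For anyone adapting this pattern, a two-GPU, pytest-driven TP test is typically structured along these lines (a hypothetical skeleton, not the contents of test/srt/test_torch_tp.py; it assumes NCCL and two visible CUDA devices):

```python
import os

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank: int, world_size: int) -> None:
    # Each spawned process joins the same process group.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    # ... build the TP-annotated model, run a forward pass, and
    # assert on the output here ...
    dist.destroy_process_group()


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="requires 2 GPUs")
def test_torch_tp_two_gpus():
    mp.spawn(_worker, args=(2,), nprocs=2, join=True)
```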
