Add tensor parallel support to T5 via NxD #697
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
This is an awesome first step. I need to investigate a bit more on the Llama side to see if this is actually compatible. The main differences I see for now are that the modeling is explicitly redefined instead of patched (because of several optimized layers/operations), and that the export/compilation uses the new ModelBuilder (and I think this will eventually be mandatory).
# Start trace
if tp_degree > 1:
    # 1. use NxD to trace for parallel
    neuron_model = neuronx_distributed.trace.parallel_model_trace(
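For context, here is a minimal sketch of how `parallel_model_trace` is typically invoked. The helper `build_parallel_t5_encoder`, the example inputs, and the exact contract of the provider callable (whether it returns just the model or a `(model, aux_state)` pair) are assumptions that depend on the `neuronx_distributed` version; treat this as an illustration, not the PR's actual code:

```python
import torch
import neuronx_distributed


def get_sharded_encoder():
    # Hypothetical helper: build the already-parallelized T5 encoder inside the
    # trace workers. Several neuronx_distributed examples return a (model, aux) pair.
    encoder = build_parallel_t5_encoder()  # assumed to exist for this sketch
    return encoder, {}


example_inputs = (
    torch.ones((1, 128), dtype=torch.long),  # input_ids
    torch.ones((1, 128), dtype=torch.long),  # attention_mask
)

neuron_encoder = neuronx_distributed.trace.parallel_model_trace(
    get_sharded_encoder,  # callable re-creating the model on each rank
    example_inputs,       # example inputs used to trace the static graph
    tp_degree=2,          # number of tensor-parallel NeuronCores
)
```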
It is OK as a first step, but the Llama example is not using this anymore. Instead it uses the ModelBuilder class, which wraps the model into NxDModel classes containing several sub-models with different input shapes (bucketing).
Is the use of bucketing mature and justified? I think we can start with parallel_model_trace anyway.
It goes a bit beyond that: prefill and decode already use two different input shapes, even before mentioning bucketing, and using the builder allows sharing the same weights between all the alternate graphs.
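To make the multiple-graph point concrete, here is a purely illustrative, library-agnostic sketch (not the NxD ModelBuilder API) of why prefill, decode, and bucketed variants end up as separate compiled graphs that a runtime has to route between, while ideally sharing one copy of the weights:

```python
# Illustrative routing only: compiled Neuron graphs have static shapes, so the
# prompt pass (prefill) and the single-token pass (decode) need separate graphs,
# and bucketing adds one prefill graph per padded sequence-length range.
PREFILL_BUCKETS = [128, 512, 1024]  # hypothetical bucket sizes


def pick_graph(graphs: dict, prompt_len: int, is_prefill: bool):
    if not is_prefill:
        return graphs["decode"]
    for bucket in PREFILL_BUCKETS:
        if prompt_len <= bucket:
            return graphs[f"prefill_{bucket}"]
    raise ValueError(f"prompt length {prompt_len} exceeds the largest bucket")
```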
Overall looks good to me
@michaelbenayoun Are you ok with the way I used Parallelizer as well? (I tried to call …
# Here we iterate over the sharded encoders and decoders: the encoder on each rank returns its cache as device tensors,
# and we assign them to the cache of the sharded decoder on the same rank to avoid the copy. The KV cache always
# uses pre-allocated memory, so there is no host-device communication overhead.
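As an aside, here is a minimal sketch of the per-rank hand-off this comment describes; the names `per_rank_encoder_cache`, `sharded_decoders`, and `past_key_values` are placeholders, not the PR's actual attributes:

```python
from typing import Any, Sequence


def bind_kv_cache(per_rank_encoder_cache: Sequence[Any], sharded_decoders: Sequence[Any]) -> None:
    """Hand each rank's encoder-produced KV cache to the decoder on the same rank."""
    for cache, decoder in zip(per_rank_encoder_cache, sharded_decoders):
        # The cache tensors already live in pre-allocated device memory on this
        # rank, so a direct assignment aliases them instead of copying through
        # the host (placeholder attribute name).
        decoder.past_key_values = cache
```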
Nice
LGTM, thanks!
What does this PR do?
Fixes #317
Fixes #479
Adds tensor parallel support to large T5 models. Example export commands:
optimum-cli export neuron --model google-t5/t5-small --tensor_parallel_size 2 --task text2text-generation --batch_size 1 --sequence_length 128 --num_beams 4 t5_neuronx_tp2/
optimum-cli export neuron --model google/flan-t5-xl --tensor_parallel_size 8 --task text2text-generation --batch_size 1 --sequence_length 128 --num_beams 4 flan_t5_xl_neuronx_tp8/
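A quick usage sketch for the first exported artifact, assuming the `NeuronModelForSeq2SeqLM` class from `optimum.neuron` and that the tokenizer was saved alongside the compiled model:

```python
from transformers import AutoTokenizer
from optimum.neuron import NeuronModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5_neuronx_tp2/")
model = NeuronModelForSeq2SeqLM.from_pretrained("t5_neuronx_tp2/")

inputs = tokenizer("translate English to German: Hello, how are you?", return_tensors="pt")
# num_beams should match the value used at export time (4 in the command above).
outputs = model.generate(**inputs, num_beams=4)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```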
Before submitting