Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tensor parallel support to T5 via NxD #697

Merged
merged 29 commits into from
Oct 24, 2024
Merged

Conversation

JingyaHuang
Copy link
Collaborator

@JingyaHuang JingyaHuang commented Sep 19, 2024

What does this PR do?

Fixes #317
Fixes #479

Add tensor parallel support to large T5 models.

  • Compile parallel T5 with Optimum CLI
  • Small functional example
optimum-cli export neuron --model google-t5/t5-small --tensor_parallel_size 2 --task text2text-generation --batch_size 1 --sequence_length 128 --num_beams 4 t5_neuronx_tp2/
  • Real scenario
optimum-cli export neuron --model google/flan-t5-xl --tensor_parallel_size 8 --task text2text-generation --batch_size 1 --sequence_length 128 --num_beams 4 flan_t5_xl_neuronx_tp8/
  • Compile parallel T5 with Modeling API
from optimum.neuron import NeuronModelForSeq2SeqLM

# 1. compile
if __name__ == "__main__":
    model_id = "google/flan-t5-xl"
    input_shapes = {
        "batch_size": 1,
        "sequence_length": 128,
        "num_beams": 4,
    }
    neuron_model = NeuronModelForSeq2SeqLM.from_pretrained(
        model_id, export=True, tensor_parallel_size=2, dynamic_batch_size=False, **input_shapes
    )
    save_path = "flan_t5_xl_neuronx_tp8/"
    neuron_model.save_pretrained(save_path)
    del neuron_model
  • Inference support for sharded T5
from optimum.neuron import NeuronModelForSeq2SeqLM
from transformers import AutoTokenizer

# 2. inference
neuron_model = NeuronModelForSeq2SeqLM.from_pretrained("flan_t5_xl_neuronx_tp8")
tokenizer = AutoTokenizer.from_pretrained("flan_t5_xl_neuronx_tp8")
prompt = "translate English to German: Lets eat good food."
inputs = tokenizer(prompt, return_tensors="pt")
num_return_sequences = 4

output = neuron_model.generate(
    **inputs,
    num_return_sequences=num_return_sequences,
)
results = [tokenizer.decode(t, skip_special_tokens=True) for t in output]

print("Results:")
for i, summary in enumerate(results):
    print(i + 1, summary)
  • Tests
  • Documentation

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@JingyaHuang JingyaHuang marked this pull request as ready for review October 22, 2024 14:47
Copy link
Collaborator

@dacorvo dacorvo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an awesome first step: I need to investigate a bit more on the LLama side to see if this is actually compatible. The main differences I see for now are that the modeling is explicitly redefined instead of patched (because of several optimized layers/operations), and the export/compilation uses the new ModelBuilder (and I think eventually this will be mandatory).

optimum/commands/export/neuronx.py Show resolved Hide resolved
optimum/commands/export/neuronx.py Outdated Show resolved Hide resolved
optimum/exporters/neuron/base.py Outdated Show resolved Hide resolved
# Start trace
if tp_degree > 1:
# 1. use NxD to trace for parallel
neuron_model = neuronx_distributed.trace.parallel_model_trace(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is ok in a first step, but for LLama example they are not using this anymore, but instead the ModelBuilder class that wraps the model into NxDModel classes that contains several sub-models with different input shapes (bucketing).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the use of bucketing mature and justified? I think we can start with parallel_model_trace anyway.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It goes a bit beyond that, because prefill / decode already use two different input shapes, not even mentioning bucketing, and using the builder allows to share the same weights between all the alternate graphs.

Copy link
Member

@michaelbenayoun michaelbenayoun left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall looks good to me

optimum/neuron/modeling_seq2seq.py Outdated Show resolved Hide resolved
@JingyaHuang
Copy link
Collaborator Author

Overall looks good to me

@michaelbenayoun Are you ok with the way I used Parallelizer as well?

(I tried to call parallelize instead of the private method but there are to many args/features trainium only that I feel hard to segment, and especially _parallelize includes everything I need to plugin the parallel layers, and except for that I only need to load the ckpt weights before tracing.)

tests/generation/test_generate.py Outdated Show resolved Hide resolved
optimum/commands/export/neuronx.py Show resolved Hide resolved
optimum/commands/export/neuronx.py Outdated Show resolved Hide resolved
optimum/exporters/neuron/__main__.py Show resolved Hide resolved
optimum/exporters/neuron/base.py Outdated Show resolved Hide resolved
optimum/exporters/neuron/utils.py Outdated Show resolved Hide resolved
optimum/neuron/modeling_seq2seq.py Outdated Show resolved Hide resolved
optimum/neuron/modeling_seq2seq.py Show resolved Hide resolved
# The KV cache always use pre-allocated memory, no host-device communication overhead.
# Here we iterate sharded encoders and decoders since the encoder on each rank will return cache as device tensors,
# we want to assign them to the cache of the sharded decoder on the same rank to avoid the copy. The KV cache always
# use pre-allocated memory, no host-device communication overhead.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice

Copy link
Collaborator

@dacorvo dacorvo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks !

@JingyaHuang JingyaHuang merged commit 1de603e into main Oct 24, 2024
11 checks passed
@JingyaHuang JingyaHuang deleted the add-tp-support-t5 branch October 24, 2024 16:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Flan-UL2 compilation failure Support for T5-11b T5-XXL + TP
4 participants