Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify Tensor Parallel implementation with PyTorch TP #34184

Merged
merged 26 commits into from
Nov 18, 2024

Conversation

kwen2501
Copy link
Contributor

@kwen2501 kwen2501 commented Oct 15, 2024

What does this PR do?

This PR uses the torch.distributed.tensor.parallel subpackage to implement Tensor Parallel for Llama (as an example).

The motivation is multi-fold:

  1. to make modeling code simple as single-worker case:
    all manual TP implementations under if self.config.pretraining_tp > 1 can be removed.

  2. to make tensor parallelism easily accessible by users:
    added a model.tensor_parallel(device_mesh) method that allows users to turn a single-proc model into a parallel model. !- Please guide me to a right place to put this function/method if PreTrainedModel is not a preferred place. -!

Note:

This is just a demo. The removal of if self.config.pretraining_tp > 1 may break workflows elsewhere, but it is just intended to calculate how much code can be saved, and hopefully it would be possible to direct it to the mechanism in this PR.

User script:

import os
import torch
import torch.distributed as dist

from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

def main(rank, world_size) -> None:
    device = torch.device(f"cuda:{rank}")
    dist.init_process_group("nccl", device_id=device)

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        tp_plan="auto",
    )

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    prompt = "Can I help"

    inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    outputs = model(inputs)

    next_token_logits = outputs[0][:, -1, :]
    next_token = torch.argmax(next_token_logits, dim=-1)
    response = tokenizer.batch_decode(next_token)
    print(f"{rank=}: {response}")


if __name__ == "__main__":
    main(int(os.environ["RANK"]), int(os.environ["WORLD_SIZE"]))

Launch command:
torchrun --standalone --nproc-per-node 4 tp_hf.py

Output:

Can I help
rank=3: [' you']
rank=1: [' you']
rank=2: [' you']
rank=0: [' you']

Test

CUDA_VISIBLE_DEVICES=0,1 pytest -sv tests/tp/test_tp.py

Doc

perf_infer_gpu_multi.md

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Models:

Library:

Integrations:

HF projects:

Cc PyTorch Distributed Team:
@gnadathur @wanchaol @fduwjj @wz337 @wconstab @tianyu-l

@kwen2501
Copy link
Contributor Author

kwen2501 commented Oct 15, 2024

Cc: @lessw2020 @HamidShojanazeri
Cc: @jerryzh168

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks really really nice! I think the only thing I would change is to have the tp_plan loadable and potentially defined in the config. Given that it does not hold any weights, it would allow other frameworks to rely on it as well! 🤗

@ArthurZucker
Copy link
Collaborator

Also we would place

    device = torch.device(f"cuda:{rank}")
    dist.init_process_group("nccl", device_id=device)
    device_mesh = dist.init_device_mesh("cuda", (world_size,))

    with device:
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
        )

potentially inside from_pretrained, this way we automatically set things for users, with them only having to change the way they call their script (with torchrun or not)

@ArthurZucker
Copy link
Collaborator

WDYT?

@ArthurZucker
Copy link
Collaborator

We will also add this to the configuration_llama.py:

DEFAULT_TP_PLAN = {
    "model.layer.*.q_proj":"column_parallel"
}

this way we can init lama config with it!

@ArthurZucker
Copy link
Collaborator

cc @SunMarc as per our discussion!

@kwen2501
Copy link
Contributor Author

kwen2501 commented Oct 23, 2024

Thanks @ArthurZucker @SunMarc for your review.

Good idea to move tp_plan to configuration_llama.py. I made that change in the current version. Indeed modeling_llama.py becomes cleaner now, and it will not have dependency on torch.distributed....

Since the top-level model can have different base_model_prefix ("model", "transformer", etc), I could not directly put FQNs like "model.layers.*" in configuration_llama.py. So I prepare a _base_model_tp_plan starting with "layers.*" facing the base model LlamaModel. This works because model.tensor_parallel() searches for and applies _tp_plan recursively. That way we can serve all top models regardless of what prefix name they give to LlamaModel, and we can also serve a bare-metal LlamaModel.

I also made the TP styles "type neutral" in the config as you suggested. Translation to torch TP types is done only when applying TP.

@kwen2501
Copy link
Contributor Author

kwen2501 commented Oct 23, 2024

Re moving distributed / device_mesh init code inside from_pretrained:

It is a good direction in making the workflow even simpler for users. As in, wouldn't it be nice if the model is distributed after it is created? I think a UI design would be helpful as to how users express the intention of "please create the model in a distributed manner".

For example, would from_pretrained accepts a related argument, like:

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_mesh=device_mesh,
)

If that kwarg is given, I think the intention is clear enough, and we can call tensor_parallel() for the user. Or even better, initialize partial weights only -- this is supported by the proposed torch TP too, but it will require slightly more work in integration with the weight loading part of from_pretained. I can try making a demo in a next PR if you are interested.

Re launcher: I might slightly prefer the above "explicit expression" to "implicit dependency" on torchrun, for keeping HF transformers "launcher neutral", as some users may use SLURM, mp.spawn, etc. Another reason was that init_process_group and init_device_mesh can take some kwargs (such as timeout, options, etc), which might be a bit big for from_pretrained to absorb if it wants to call them inside.

Regardless, I agree that it would be very nice if we can create distributed models out of `from_pretrained" directly, and I am really excite to help that way :)

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks a lot better!
I think we want to have some small tests to make sure we properly serialize the TP-Plan, and also properly dispatch the model (slow tests, we can probably add them ourselves!)

  • would need some documentation as well, we can do that in a follow up PR for sure

On of the most important question now is: What's the approach regarding ForCausalLM.

This might be for @SunMarc as it's transformers internal:

_base_model_tp_plan will be consumed by the PreTrainedModel ➡ this way we can enforce some checks:

  • check that all models that inherit from PreTrainedModel have a TP plan
  • check nesting (ForCausalLM has a module that has a TP Plan, if it has other module they should have a TP plan

Comment on lines 336 to 354
def translate_to_torch_parallel_style(style: str):
"""
In model configurations, we use a neutral type (string) to specify parallel
styles, here we translate them into torch.distributed tensor-parallel
types.
"""
if not isinstance(style, str):
raise ValueError(
f"Unsupported parallel style type {type(style)}, expected str"
)

if style == "colwise":
return ColwiseParallel()
elif style == "rowwise":
return RowwiseParallel()
elif style == "colwise_rep":
return ColwiseParallel(output_layouts=Replicate())
else:
raise ValueError(f"Unsupported parallel style value: {style}")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A mapping from tp style to the correct function might be better.
We are also thinking of defining a TensorParallelConfig, your feedback is welcome here, as I don't know the variety of classes / args that might be used!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the comment! Indeed a mapping style would look better.
The only caveat is that the returned value here is an object rather than a constant or class (see the () behind ColwiseParallel). If we prepare a map, we may be always returning the same object -- the parallelize_module API may be able to support it I guess, I am just not sure if there is a contract guaranteeing that today.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pitching in :)

We should be able to use the same object since it applies required parallel operation to the module and returns a new copy - https://github.com/pytorch/pytorch/blob/86d4b7d60b264cae5a04a1b20719bcd7a5752a4c/torch/distributed/tensor/parallel/api.py#L95

Have also tested it empirically while benchmarking (#34194)

Thanks!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SOunds good!

@@ -141,6 +141,16 @@ class LlamaConfig(PretrainedConfig):

model_type = "llama"
keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for base model `LlamaModel`
_base_model_tp_plan = {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to allow this for external use by removing _ so that we can allow users to define tp plan tweaks from config.json?

Given that, shall we as well allow for providing custom tp plan as input to LlamaConfig() that overrides the default?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, good idea. We can make this public once we prove things work.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, base_model_tp_plan should be supported as input to the PreTrainedConfig!

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, this variable base_model_tp_plan has to be added to PreTrainedConfig

class PretrainedConfig(PushToHubMixin):
with a default value as an empty dict {} which I believe is best possible default for any config sub class inheriting.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @kmehant @ArthurZucker for the suggestion. I moved base_model_tp_plan to PretrainedConfig in the latest commit.

@kwen2501
Copy link
Contributor Author

Summarizing my discussion with @kmehant re collaboration plan between this PR and #34194:

  • Both PRs agree on the interest in having PyTorch TP in transformers :)
  • We agree that the integration can hopefully happen at a broad base so that it can scale to a set of models, thus we agree on picking PretrainedModel as an integration point.
  • #34194 has changes on the trainer side. And there is a corresponding PR for changes on Accelerate side. So those two PRs focus on workflow-level integration. So hopefully this PR can serve as a base for those two PRs and once landed, unblock the other two PRs.

@ArthurZucker
Copy link
Collaborator

Sounds good, let's focus on this one first then!

@kwen2501
Copy link
Contributor Author

Dumb question: is there a way to bypass the check_repository_consistency test? Thanks! @ArthurZucker @SunMarc

@kwen2501
Copy link
Contributor Author

Okay, I just temporarily disabled the "Copied from" comments in some derivative models. Will add them back once we extend TP plans to those models.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey! So basically places where you are disabling it , you just need to do something like:
# Copied from transformers.models.llama.modeling_llama.LlamaModel with Llama->Olmoe,lm_head->new_name for example

and before anything run make fix-copies which would apply your changes!

@kwen2501
Copy link
Contributor Author

Thanks @ArthurZucker . I tried that method, but it doesn't seem to work.
Specifically, the test is checking for consistency between modeling_gemma.py and modular_gemma.py. I edited neither, but the CI complains that they are not consistent. Here is the link.

@ArthurZucker
Copy link
Collaborator

Ah that is because you probably modified modeling directly, when these two have a modular file, which has the source of truth!

@kwen2501
Copy link
Contributor Author

kwen2501 commented Nov 15, 2024

@ArthurZucker Thanks again for the suggestions.

  1. A test has been added at tests/tp/test_tp.py, run command:
CUDA_VISIBLE_DEVICES=0,1 pytest -sv tests/tp/test_tp.py
  1. A new doc has been added: perf_infer_gpu_multi.md, titled "Multi-GPU inference".
    A use example and the benchmark figure have been added to the doc.
    @ArthurZucker can you please help merge this image upload for doc? Thanks!
  2. @kmehant The .has_tp_plan property has been added.

@ArthurZucker
Copy link
Collaborator

Merged! let me review one last time 🤗

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work! 🔥 some small things left to do but super close to merge!

docs/source/en/perf_infer_gpu_multi.md Outdated Show resolved Hide resolved
docs/source/en/performance.md Show resolved Hide resolved
@@ -1442,6 +1449,9 @@ def post_init(self):
"""
self.init_weights()
self._backward_compatibility_gradient_checkpointing()
# If current model is a base model, attach `base_model_tp_plan` from config
if self.base_model is self:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, but self._tp_plan should be None only for the base model no?
We could maybe add something like if not self.base_model is self and self._tp_plan is None and self.supports_tp_plan raise an error in the futur, to force people to add the TP plan.

src/transformers/modeling_utils.py Outdated Show resolved Hide resolved
src/transformers/modeling_utils.py Show resolved Hide resolved
src/transformers/modeling_utils.py Outdated Show resolved Hide resolved
tests/tp/test_tp.py Outdated Show resolved Hide resolved
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, would be nice to have a generic way for all models that have support_tp_plan=True instead of just llama, setting everyone up for success in adding support for other models!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, both me and @kmehant are interested in extending this mechanism to other models. We can think of an way to "copy" the tp_plan to other models with similar architectures (just like how HF copies modeling code between similar models). Maybe covering all AutoModelForCausalLM would be a good next step. Follow-up PR will come soon :)

Copy link
Contributor Author

@kwen2501 kwen2501 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @ArthurZucker for the close look! I fixed the places you pointed out.

docs/source/en/performance.md Show resolved Hide resolved
tests/tp/test_tp.py Outdated Show resolved Hide resolved
src/transformers/modeling_utils.py Outdated Show resolved Hide resolved
src/transformers/modeling_utils.py Outdated Show resolved Hide resolved
docs/source/en/perf_infer_gpu_multi.md Outdated Show resolved Hide resolved
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, both me and @kmehant are interested in extending this mechanism to other models. We can think of an way to "copy" the tp_plan to other models with similar architectures (just like how HF copies modeling code between similar models). Maybe covering all AutoModelForCausalLM would be a good next step. Follow-up PR will come soon :)

@kwen2501
Copy link
Contributor Author

@ArthurZucker CI is all green now. (Previous failure seems to come from infra instability). Is there anything else I should add before merge? Thanks!

@ArthurZucker
Copy link
Collaborator

Nope merging! 🤗

@ArthurZucker ArthurZucker merged commit 20142ab into huggingface:main Nov 18, 2024
22 checks passed
@ArthurZucker
Copy link
Collaborator

I wanted to check if torch._C.get_device_accelerator had been available for a while or not!

@ArthurZucker
Copy link
Collaborator

SUper kudos for the PR! 🚀

@loadams
Copy link
Contributor

loadams commented Nov 18, 2024

Hi @ArthurZucker and @kwen2501 - it seems this PR now requires that users have torch 2.5+, is that correct? When trying to run the DeepSpeed unit tests with the latest transformers and torch 2.3 I'm getting errors around not being able to import from torch.distributed.tensor:

from torch.distributed.tensor import Replicate

The specific error:

E   RuntimeError: Failed to import transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder because of the following error (look up to see its traceback):
E   Failed to import transformers.generation.utils because of the following error (look up to see its traceback):
E   cannot import name 'Replicate' from 'torch.distributed.tensor' (/usr/local/lib/python3.10/dist-packages/torch/distributed/tensor/__init__.py)

Opened #34795 to track if that's better visibility wise.

@kmehant
Copy link

kmehant commented Nov 19, 2024

@ArthurZucker @kwen2501 Planning for a HF blog? I am in!

@kwen2501
Copy link
Contributor Author

@loadams Thanks for reporting it and sorry about the break.
I think that's because Replicate was previously in torch.distributed._tensor.
Let me open a PR to build backward compat support.
Cc: @wz337

@ArthurZucker
Copy link
Collaborator

ArthurZucker commented Nov 25, 2024

Sure! Won't have bandwidth to start a blog but feel free to do it! 🤗

BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
…34184)

* Simplify Tensor Parallel implementation with PyTorch TP

* Move tp_plan to config

* Lint

* Format and warning

* Disable copy-from check

* Conditionally get attr from config

* make fix-copies

* Move base_model_tp_plan to PretrainedConfig

* Move TP into from_pretrained

* Add device context for load

* Do not serialize

* Move _tp_plan setting to post_init

* Add has_tp_plan

* Add test_tp

* Add 'Multi-gpu inference' doc

* Add backward support for device type identification

* Auto-detect accelerator

* supports_tp_plan

* copyright year

* Fix copy
BernardZach pushed a commit to innovationcore/transformers that referenced this pull request Dec 6, 2024
…34184)

* Simplify Tensor Parallel implementation with PyTorch TP

* Move tp_plan to config

* Lint

* Format and warning

* Disable copy-from check

* Conditionally get attr from config

* make fix-copies

* Move base_model_tp_plan to PretrainedConfig

* Move TP into from_pretrained

* Add device context for load

* Do not serialize

* Move _tp_plan setting to post_init

* Add has_tp_plan

* Add test_tp

* Add 'Multi-gpu inference' doc

* Add backward support for device type identification

* Auto-detect accelerator

* supports_tp_plan

* copyright year

* Fix copy
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants