Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TPU] Support PyTorch/XLA FSDP via SPMD #28949

Merged
merged 10 commits into from
Feb 14, 2024
Merged

Conversation

alanwaketan
Copy link
Contributor

@alanwaketan alanwaketan commented Feb 9, 2024

What does this PR do?

Summary:
This is the first attempt to enable FSDP via SPMD (FSDPv2) on PyTorch/XLA model.

More information about FSDPv2 can be found here:

  1. A user guide: https://github.com/pytorch/xla/blob/master/docs/fsdpv2.md
  2. A RFC: [RFC] FSDP via SPMD pytorch/xla#6379

Besides the initial implementation of FSDPv2 in r2.2, this change will also requires the following changes in PyTorch/XLA:

  1. [FSDPv2] Enable auto-wrapping pytorch/xla#6499
  2. [FSDPv2] Use the global mesh API pytorch/xla#6500
  3. [SPMD] Introduce global mesh pytorch/xla#6498
  4. [FSDPv2] Move the module to xla device pytorch/xla#6525
    Therefore, it will only be compatible with the nightly builds.

Example use cases:

  1. Prepare a FSDPv2 config:
{
    "fsdp_transformer_layer_cls_to_wrap": [
        "LlamaDecoderLayer"
    ],
    "xla": true,
    "xla_fsdp_v2": true,
    "xla_fsdp_grad_ckpt": true
}
  1. Invoke the trainer using the following command:
XLA_USE_SPMD=1 XLA_USE_BF16=1 python3 examples/pytorch/language-modeling/run_clm.py     --num_train_epochs 1     --dataset_name wikitext     --dataset_config_name wikitext-2-raw-v1 --per_device_train_batch_size 128     --do_train     --output_dir /tmp/test-clm     --overwrite_output_dir     --config_name ../transformers_pt/2B.config     --cache_dir /tmp     --tokenizer_name hf-internal-testing/llama-tokenizer      --block_size 1024     --optim adafactor     --save_strategy no     --logging_strategy no --fsdp "full_shard" --fsdp_config fsdp_config.json --torch_dtype bfloat16 --dataloader_drop_last yes

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@ArthurZucker @younesbelkada

@alanwaketan
Copy link
Contributor Author

Can HF folks point me on how to add test case in this case and also how to update the documentation?

@alanwaketan
Copy link
Contributor Author

cc @yeounoh @jonb377

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM overall! We might want to add a small test, it can be done in a followup PR.
Pinging @muellerzr for a second look!

Comment on lines +173 to +175
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not super fan of super short names but seems common in trainer!

src/transformers/trainer.py Outdated Show resolved Hide resolved
@ArthurZucker
Copy link
Collaborator

Tests should be added in the tests/trainer/test_trainer.py file. You should find similar tests!

Copy link
Contributor

@muellerzr muellerzr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As @ArthurZucker hinted at, we now don't handle things like this in the trainer directly. I would rather see this code over in accelerate which we can then bring into Trainer automatically since it relies on it for preparation. Especially as this deals with the dataloaders. Would that be possible please! :)


if self.is_fsdp_xla_v2_enabled:
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
SpmdFullyShardedDataParallel as FSDPv2,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we make this easier by importing FSDPv2 as FSDP instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May I ask what's the benefits of doing so?

raise ValueError("Something went wrong, the output of the model shouldn't be `None`")
xs.mark_sharding(real_output, mesh, ("fsdp", None, None))

self.model = model = FSDPv2(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And then leave the check for down here on what to do.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shard_output is not used by FSDPv1. Shouldn't we guard that with the flag too?

@alanwaketan
Copy link
Contributor Author

As @ArthurZucker hinted at, we now don't handle things like this in the trainer directly. I would rather see this code over in accelerate which we can then bring into Trainer automatically since it relies on it for preparation. Especially as this deals with the dataloaders. Would that be possible please! :)

Can you elaborate it a bit more? I can move the model = model.to(xm.xla_device()) logic. But for the dataloader logic, i.e., tpu_spmd_dataloader, where do you suggest me to move it to?

@alanwaketan
Copy link
Contributor Author

Tests should be added in the tests/trainer/test_trainer.py file. You should find similar tests!

Speaking of adding tests, what should I test? I mean do you have TPU CI?

# PyTorch/XLA relies on the data loader to insert the mark_step for
# each step. Since we are breaking the loop early, we need to manually
# insert the mark_step here.
if is_torch_tpu_available():
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I fixed a bug here. cc @ArthurZucker @jonb377

@alanwaketan
Copy link
Contributor Author

The test failures don't seem to be related. I tried rebasing as well.

@alanwaketan
Copy link
Contributor Author

Thanks @ArthurZucker and @muellerzr for approving the change.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@alanwaketan
Copy link
Contributor Author

It's all green. Can HF folks help with landing the PR? Appreciate it.

@amyeroberts
Copy link
Collaborator

I can merge :) Thanks for adding this support @alanwaketan!

@amyeroberts amyeroberts merged commit 5f06053 into huggingface:main Feb 14, 2024
22 checks passed
itazap pushed a commit that referenced this pull request May 14, 2024
* Initial commit

* Add guards for the global mesh

* Address more comments

* Move the dataloader into integrations/tpu.py

* Fix linters

* Make karg more explicitly

* Remove the move device logic

* Fix the CI

* Fix linters

* Re-enable checkpointing
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants