[TPU] Support PyTorch/XLA FSDP via SPMD #28949
Conversation
Can HF folks point me to how to add a test case here, and also how to update the documentation?
LGTM overall! We might want to add a small test; it can be done in a follow-up PR.
Pinging @muellerzr for a second look!
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
I'm not a big fan of super short names, but it seems common in Trainer!
Tests should be added in the
As @ArthurZucker hinted at, we no longer handle things like this in the Trainer directly. I would rather see this code live in Accelerate, which Trainer can then pick up automatically since it relies on it for preparation, especially as this deals with the dataloaders. Would that be possible, please? :)
if self.is_fsdp_xla_v2_enabled:
    from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
        SpmdFullyShardedDataParallel as FSDPv2,
    )
Could we make this easier by importing FSDPv2 as FSDP instead?
May I ask what the benefit of doing so would be?
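For illustration, the aliasing being suggested would look roughly like the sketch below. The helper name, the v1 import path, and the kwargs are assumptions about the surrounding Trainer code, not the merged diff:

def wrap_model(model, fsdp_xla_v2_enabled, fsdp_kwargs):
    # Sketch of the suggestion: alias both wrappers to the same name so the call site is shared.
    if fsdp_xla_v2_enabled:
        from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
            SpmdFullyShardedDataParallel as FSDP,
        )
    else:
        from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
    # The constructor arguments still differ between v1 and v2 (e.g. a v2-only shard_output
    # callback), which is presumably the caveat behind the question about the benefit.
    return FSDP(model, **fsdp_kwargs)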
raise ValueError("Something went wrong, the output of the model shouldn't be `None`") | ||
xs.mark_sharding(real_output, mesh, ("fsdp", None, None)) | ||
|
self.model = model = FSDPv2(
And then leave the check down here to decide what to do.
shard_output is not used by FSDPv1. Shouldn't we guard that with the flag too?
Can you elaborate a bit more? I can move the
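For reference, a rough sketch of the callback under discussion, assuming shard_output is the v2-only hook passed to the FSDPv2 constructor; the unwrapping of the model output here is illustrative, not the exact Trainer code:

import torch
import torch_xla.distributed.spmd as xs

def shard_output(output, mesh):
    # Find the tensor to shard; model outputs may be a tensor, a tuple, or an object with .logits.
    real_output = None
    if isinstance(output, torch.Tensor):
        real_output = output
    elif isinstance(output, tuple):
        real_output = output[0]
    elif hasattr(output, "logits"):
        real_output = output.logits
    if real_output is None:
        raise ValueError("Something went wrong, the output of the model shouldn't be `None`")
    # Shard the batch dimension along the "fsdp" mesh axis; leave the other dims replicated.
    xs.mark_sharding(real_output, mesh, ("fsdp", None, None))

Since FSDPv1 does not accept this callback, the idea would be to pass shard_output only on the path where is_fsdp_xla_v2_enabled is set.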
Speaking of adding tests, what should I test? Do you have TPU CI?
# PyTorch/XLA relies on the data loader to insert the mark_step for
# each step. Since we are breaking the loop early, we need to manually
# insert the mark_step here.
if is_torch_tpu_available():
I fixed a bug here. cc @ArthurZucker @jonb377
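For context, a minimal sketch of the early-exit pattern the fix concerns; the helper name and loop are illustrative, not the Trainer code itself:

import torch_xla.core.xla_model as xm

def take_first_batches(dataloader, num_batches, on_tpu):
    # PyTorch/XLA relies on the parallel loader to call mark_step once per iteration.
    # When we break out of the loop early, that mark_step never runs, so we issue it
    # manually to make sure the pending XLA operations are actually executed.
    batches = []
    for batch in dataloader:
        batches.append(batch)
        if len(batches) >= num_batches:
            if on_tpu:
                xm.mark_step()
            break
    return batches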
The test failures don't seem to be related. I tried rebasing as well.
Thanks @ArthurZucker and @muellerzr for approving the change.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
It's all green. Can HF folks help with landing the PR? Appreciate it.
I can merge :) Thanks for adding this support @alanwaketan!
* Initial commit
* Add guards for the global mesh
* Address more comments
* Move the dataloader into integrations/tpu.py
* Fix linters
* Make karg more explicitly
* Remove the move device logic
* Fix the CI
* Fix linters
* Re-enable checkpointing
What does this PR do?
Summary:
This is the first attempt to enable FSDP via SPMD (FSDPv2) for PyTorch/XLA models.
More information about FSDPv2 can be found here:
Besides the initial implementation of FSDPv2 in r2.2, this change also requires the following changes in PyTorch/XLA:
Therefore, it will only be compatible with the nightly builds.
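For orientation, the SPMD primitives this integration builds on look roughly like the following. This is a sketch of the nightly torch_xla API as I understand it, not code from this PR:

import numpy as np
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

# Put the runtime into SPMD mode and build a mesh with an "fsdp" axis over all devices.
xr.use_spmd()
num_devices = xr.global_runtime_device_count()
device_ids = np.arange(num_devices)
mesh = xs.Mesh(device_ids, (num_devices, 1), ("fsdp", "tensor"))
xs.set_global_mesh(mesh)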
Example use cases:
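The PR's own example invocations are not reproduced here; a hypothetical way to opt in from user code might look like this, where the fsdp_config keys (in particular "xla_fsdp_v2") are assumptions about this PR's configuration surface:

from transformers import TrainingArguments

# Hypothetical configuration; the exact keys are assumed, not confirmed by this PR text.
training_args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=8,
    fsdp="full_shard",
    fsdp_config={
        "xla": True,               # route FSDP through PyTorch/XLA
        "xla_fsdp_v2": True,       # assumed flag enabling the SPMD-based FSDPv2 path
        "xla_fsdp_grad_ckpt": True,
    },
)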
Before submitting
* Did you read the contributor guideline, Pull Request section?
* Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
* Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@ArthurZucker @younesbelkada