
Implement Fully Sharded Data Parallel (FSDP) in PyTorch XLA #3431

Merged: 20 commits merged into pytorch:master on May 9, 2022

Conversation

@ronghanghu (Collaborator) commented Mar 18, 2022

Fully Sharded Data Parallel (FSDP) in PyTorch XLA

This PR implements Fully Sharded Data Parallel (FSDP) in PyTorch XLA for sharding Module parameters across data-parallel workers.

Example usage:

import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

# move the module to the XLA device, then wrap it with FSDP
my_module = my_module.to(xm.xla_device())
model = FSDP(my_module)

optim = torch.optim.Adam(model.parameters(), lr=0.0001)
output = model(x, y)
loss = output.sum()
loss.backward()
optim.step()

It is also possible to shard individual layers separately and have an outer wrapper handle any leftover parameters.
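
For instance, a minimal sketch of this nested wrapping pattern (MyModule, block1, and block2 are hypothetical placeholders for your own model, not names from this PR):

import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

# MyModule, block1, and block2 are hypothetical placeholders
module = MyModule().to(xm.xla_device())
module.block1 = FSDP(module.block1)  # shard an individual layer separately
module.block2 = FSDP(module.block2)  # shard another layer separately
model = FSDP(module)                 # outer wrapper handles any leftover parameters

optim = torch.optim.Adam(model.parameters(), lr=0.0001)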

Notes:

  • The XlaFullyShardedDataParallel class supports both the ZeRO-2 optimizer (sharding gradients and optimizer states) and the ZeRO-3 optimizer (sharding parameters, gradients, and optimizer states) described in https://arxiv.org/abs/1910.02054.
    • The ZeRO-3 optimizer should be implemented via nested FSDP with reshard_after_forward=True. See test/test_train_mp_mnist_fsdp_with_ckpt.py and test/test_train_mp_imagenet_fsdp.py for examples.
    • For large models that cannot fit into a single TPU memory or the host CPU memory, one should interleave submodule construction with inner FSDP wrapping. See FSDPViTModel for an example.
  • A simple wrapper checkpoint_module is provided (based on torch_xla.utils.checkpoint.checkpoint from #3524, "Add Checkpoint api to use optimization barrier") to perform gradient checkpointing over a given nn.Module instance. See test/test_train_mp_mnist_fsdp_with_ckpt.py and test/test_train_mp_imagenet_fsdp.py for examples, and the sketch after these notes.
  • When stepping the optimizer, directly call optimizer.step and do not call xm.optimizer_step. The latter reduces gradients across ranks, which is not needed for FSDP (where the parameters are already sharded).
  • When saving model and optimizer checkpoints during training, each training process needs to save its own checkpoint of the (sharded) model and optimizer state dicts (use master_only=False and set a different path for each rank in xm.save). When resuming, it needs to load the checkpoint for the corresponding rank.
  • Please also save model.get_shard_metadata() along with model.state_dict() as follows, and use consolidate_sharded_model_checkpoints to stitch the sharded model checkpoints together into a full model state dict. See test/test_train_mp_mnist_fsdp_with_ckpt.py for an example.
ckpt = {
    'model': model.state_dict(),
    'shard_metadata': model.get_shard_metadata(),
    'optimizer': optimizer.state_dict(),
}
ckpt_path = f'/tmp/rank-{xm.get_ordinal()}-of-{xm.xrt_world_size()}.pth'
xm.save(ckpt, ckpt_path, master_only=False)
  • The checkpoint consolidation script can also be launched from the command line as follows.
# consolidate the saved checkpoints via command line tool
python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts \
  --ckpt_prefix /path/to/your_sharded_checkpoint_files \
  --ckpt_suffix "_rank-*-of-*.pth"
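
As mentioned in the notes above, here is a minimal sketch of combining checkpoint_module with FSDP. It assumes checkpoint_module is importable from torch_xla.distributed.fsdp (as used in the test scripts) and that module.block is a hypothetical submodule of your model:

from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
from torch_xla.distributed.fsdp import checkpoint_module  # assumed export location

# module.block is a hypothetical submodule; checkpoint_module makes its forward
# activations be recomputed during the backward pass instead of being stored
module.block = FSDP(checkpoint_module(module.block), reshard_after_forward=True)
model = FSDP(module)  # outer wrapper for the remaining parameters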

The implementation of this class is largely inspired by and mostly follows the structure of fairscale.nn.FullyShardedDataParallel in https://fairscale.readthedocs.io/en/stable/api/nn/fsdp.html. One of the biggest differences from fairscale.nn.FullyShardedDataParallel is that in XLA we don't have explicit parameter storage, so here we resort to a different approach to free full parameters for ZeRO-3.


Example training scripts on MNIST and ImageNet

  • MNIST: test/test_train_mp_mnist_fsdp_with_ckpt.py (it also tests checkpoint consolidation)
  • ImageNet: test/test_train_mp_imagenet_fsdp.py

Installation

Update (05/05/2022): This XLA FSDP PR requires the 20220505 nightly XLA wheels and the 20220413 libtpu. For TPU VMs, start from the tpu-vm-pt-1.10 runtime and install the nightly wheels via:

# torch, torchvision and torch_xla 20220505
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch-nightly+20220505-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torchvision-nightly+20220505-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly+20220505-cp38-cp38-linux_x86_64.whl

# libtpu 20220413
sudo pip3 install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20220413-py3-none-any.whl

# install the FSDP PR into PyTorch/XLA
cd ~ && rm -rf xla_fsdp_dev && git clone https://github.com/ronghanghu/xla.git xla_fsdp_dev
cd xla_fsdp_dev && git checkout xla_fsdp_rebased
sudo rm -rf /usr/local/lib/python3.8/dist-packages/torch_xla/distributed/fsdp
sudo cp -r ./torch_xla/distributed/fsdp /usr/local/lib/python3.8/dist-packages/torch_xla/distributed/
cd ~

Train MNIST on v3-8 TPU

It reaches around 98.9% accuracy in 2 epochs:

python3 -u ~/xla_fsdp_dev/test/test_train_mp_mnist_fsdp_with_ckpt.py \
  --batch_size 16 --drop_last --num_epochs 2 \
  --use_nested_fsdp --use_gradient_checkpointing

This script automatically tests checkpoint consolidation at the end. You can also manually consolidate the sharded checkpoints via

# consolidate the saved checkpoints via command line tool
python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts \
  --ckpt_prefix /tmp/mnist-fsdp/final_ckpt \
  --ckpt_suffix "_rank-*-of-*.pth"
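
The same consolidation can also be done from Python. A hedged sketch, assuming consolidate_sharded_model_checkpoints is importable from torch_xla.distributed.fsdp and accepts ckpt_prefix/ckpt_suffix arguments mirroring the CLI flags above:

from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints

# assumed to mirror the command-line flags; the glob suffix matches the per-rank files
consolidate_sharded_model_checkpoints(
    ckpt_prefix='/tmp/mnist-fsdp/final_ckpt',
    ckpt_suffix='_rank-*-of-*.pth')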

Train ImageNet with ResNet-50 on v3-8 TPU

It reaches around 75.9% accuracy in 100 epochs. Download ImageNet-1k to /datasets/imagenet-1k first:

python3 -u ~/xla_fsdp_dev/test/test_train_mp_imagenet_fsdp.py \
  --datadir /datasets/imagenet-1k --drop_last \
  --model resnet50 --test_set_batch_size 64 --eval_interval 10 \
  --lr 0.4 --batch_size 128 --num_warmup_epochs 5 --lr_scheduler_divide_every_n_epochs 30 --lr_scheduler_divisor 10 --num_epochs 100 \
  --use_nested_fsdp

You can also add --use_gradient_checkpointing (which needs to be used along with --use_nested_fsdp) to apply gradient checkpointing to the residual blocks.


Example training scripts on TPU pod (ViT with 10 billion parameters)

To train large models that cannot fit into a single TPU device's memory, one should use nested FSDP (wrapping sub-modules with inner FSDP while building the entire model) to implement the ZeRO-3 algorithm.

Please see https://github.com/ronghanghu/vit_10b_fsdp_example for an example of sharded training of a Vision Transformer (ViT) model using this XLA FSDP PR.
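
A minimal sketch of the interleaved-construction pattern described above (Block, num_layers, and hidden_dim are hypothetical placeholders, not the actual FSDPViTModel code):

import torch.nn as nn
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

device = xm.xla_device()
blocks = []
for _ in range(num_layers):  # num_layers and Block are hypothetical placeholders
    # build one submodule at a time and immediately shard it, so the full
    # unsharded model is never materialized on a single device or host
    block = Block(hidden_dim).to(device)
    blocks.append(FSDP(block, reshard_after_forward=True))  # ZeRO-3 style
model = FSDP(nn.Sequential(*blocks))  # outer wrapper for any leftover parameters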


TODOs

  • Format the code and resolve lint errors.
  • Test integration with gradient checkpointing.
    • Add a simple wrapper to apply gradient checkpointing to a module.
  • Test scaling to a 10B+ parameter ViT model on TPU v3 with this FSDP implementation.
    • Add an example script to use this XLA FSDP PR to train a 10B+ parameter ViT model.
  • Clean up and address comments from reviewers.
  • Run full suite tests on v3-8 and v3-128 before merging.
  • [ ] Speed-memory tradeoff analysis and profiling (partially blocked by #3446, "XLA profiler cannot capture TPU device trace when running on a pod").

Related issues affecting this PR: pytorch/pytorch#74424, #3330, #3392, #3423, #3441, #3446, #3453, #3455, #3502, #3506, #3510, #3545

@ronghanghu ronghanghu force-pushed the xla_fsdp_rebased branch 13 times, most recently from 7922759 to f15140b on March 20, 2022 07:13
@ronghanghu ronghanghu marked this pull request as ready for review March 21, 2022 12:34
@yeounoh yeounoh self-requested a review March 22, 2022 16:23
@miladm miladm self-requested a review March 22, 2022 16:33
@miladm miladm requested a review from mrshenli March 31, 2022 16:17
@hjm-aws (Collaborator) left a comment

Awesome PR. Can't wait to use it!

@hjm-aws hjm-aws self-requested a review March 31, 2022 18:20
@miladm (Collaborator) commented Apr 6, 2022

I successfully ran FSDP to train ResNet on TPUv3.
Reviewing the code.

@hjm-aws (Collaborator) commented Apr 7, 2022

@ronghanghu Hi Ronghang, just to confirm I read the code correctly: we don't have prefetch logic here, right? By prefetching I mean pre-all-gather the params for the next layer/submodule.

@ronghanghu (Collaborator, Author) commented Apr 7, 2022

> @ronghanghu Hi Ronghang, just to confirm I read the code correctly: we don't have prefetch logic here, right? By prefetching I mean pre-all-gather the params for the next layer/submodule.

@hjm-aws You are right, we don't have prefetch logic here. The all-gather of the next layer happens after the previous layer is completed (instead of happening in parallel). This is consistent with FairScale's FSDP implementation (and it also saves more memory compared to parallel prefetching).

@ronghanghu (Collaborator, Author) commented Apr 12, 2022

As a temporary workaround to #3455 (comment), I added a mark_step_on_freeing parameter to __init__. If it is True, we call xm.mark_step immediately upon freeing full parameters.

This is a makeshift and inefficient workaround to avoid XLA compiler fusion that breaks parameter freeing in nested FSDP (ZeRO-3). It is useful only when reshard_after_forward is True.

This workaround notably increases the execution time, triggers more compilation, and does not cover all the cases, so we need a permanent solution to #3455 (comment). I'll remove the mark_step_on_freeing option when #3455 is fixed.
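
For reference, a minimal sketch of enabling this temporary flag on an inner (nested) FSDP wrap, using the parameter name described above (submodule is a placeholder for one of your model's layers):

from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

# temporary workaround: mark a step right after freeing full parameters;
# only meaningful together with reshard_after_forward=True (nested / ZeRO-3)
inner = FSDP(submodule, reshard_after_forward=True, mark_step_on_freeing=True)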

(It seems that the CI build failed due to issues unrelated to this PR; this PR does not touch any C++ codebase.)

@ronghanghu (Collaborator, Author) commented May 6, 2022

(rebased against master branch)

@ronghanghu (Collaborator, Author) commented May 7, 2022

Update: our internal scaling tests show that we can scale up to transformer models with 60 billion parameters on v3-512 using this FSDP PR. Beyond that, we run into issue #3453 and cannot scale to even larger model sizes.

Computation requires more parameters (4748) than supported (limit 3304).

Until there is a long-term fix for #3453, we may want to try some short-term workarounds (e.g. inserting xm.mark_step between layers in our model) to see whether we can get around it if/when we need to train a GPT-3-scale model with 175B parameters on TPU with PyTorch XLA. But I'm satisfied with the current scaling behavior of this PR, given that 60B already covers most of our use cases under the default setting without extra xm.mark_step.

@JackCaoG (Collaborator) commented May 9, 2022

This is exciting! Thanks @ronghanghu !

@miladm (Collaborator) commented May 9, 2022

This is awesome @ronghanghu! Thank you!
Do you plan to submit a separate ViT PR with your changes?

@ronghanghu (Collaborator, Author) commented May 9, 2022

> This is awesome @ronghanghu! Thank you! Do you plan to submit a separate ViT PR with your changes?

@miladm I think this PR should be ready for merge in its current shape once our internal accuracy tests pass. So far the loss curves of these accuracy tests look good but I'm still waiting to confirm their final results (should know by tomorrow). The current PR commits above are exactly what we used internally at Meta AI in our scaling runs (including our ViT and other experiments).

Regarding the ViT scaling example, I would like to keep it as a separate repository at https://github.com/ronghanghu/vit_10b_fsdp_example for now because 1) it involves many open research questions on how to best train ViTs at the 10-60 billion parameter scale (such as which optimizer and learning rate to use, and how to avoid training divergence), where we don't yet have an established recipe like the ImageNet or MNIST examples above; 2) it's more of an example of a downstream application than a feature of PyTorch XLA; and 3) it relies on external dependencies such as timm to build the models.

So I think it's better to have a separate repository for scaling examples on top of PyTorch XLA as opposed to having them as part of PyTorch XLA itself. This is similar to keeping a separate T5X repository on top of JAX.

@miladm (Collaborator) commented May 9, 2022

Thanks @ronghanghu. I agree, it should be on a separate repo. Thanks for sharing the reference to it on this PR.

@miladm (Collaborator) commented May 9, 2022

@ronghanghu Feel free to merge the PR.

@miladm (Collaborator) left a comment

LGTM. Merging the PR.

@miladm miladm merged commit 3c83269 into pytorch:master May 9, 2022
@hjm-aws (Collaborator) commented May 9, 2022

@ronghanghu Hi Ronghang, we (AWS) plan to add ZeRO-2 optimization into your FSDP implementation. ZeRO-2 is just ZeRO-3 minus parameter sharding. We plan to add a flag to the constructor, and when the flag is True, we'll just skip the parameter sharding during init and skip all-gather before forward/backward. What do you think?

@ronghanghu (Collaborator, Author) commented May 9, 2022

Hi @hjm-aws, you can already do ZeRO-2 by wrapping the entire model with a single (instead of nested) FSDP class with reshard_after_forward=False, and then applying any PyTorch optimizers on the FSDP model's .parameters(). In this way, the full parameters will be all-gathered for the entire model before the forward pass and the gradients will be distributed via reduce-scatter.
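
For example, a minimal sketch of this ZeRO-2 setup (my_module is a placeholder for the full, unwrapped model):

import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

# a single flat (non-nested) FSDP wrap; keep full parameters after forward (ZeRO-2)
model = FSDP(my_module.to(xm.xla_device()), reshard_after_forward=False)
optim = torch.optim.Adam(model.parameters(), lr=0.0001)  # any stock PyTorch optimizer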

The FairScale library originally introduced an older SDP class, fairscale.nn.ShardedDataParallel, for ZeRO-2, as described in https://fairscale.readthedocs.io/en/latest/deep_dive/oss_sdp_fsdp.html. It has since become largely obsolete, as the newer FSDP class fairscale.nn.FullyShardedDataParallel can implement both ZeRO-2 (a single plain FSDP wrap + reshard_after_forward=False) and ZeRO-3 (nested FSDP + reshard_after_forward=True).

Using the newer fairscale.nn.FullyShardedDataParallel class is actually the recommended way to do ZeRO-2; it is equivalent to, but more convenient than, putting ZeRO-2's all-gather/broadcast op in a separate sharded-optimizer class (which handles the sharded gradient update and full-parameter all-gather/broadcast, as in the alternative implementation of fairscale.nn.ShardedDataParallel). With FSDP, you can directly use stock PyTorch optimizers such as SGD or Adam.


If one really needs the alternative implementation of ZeRO-2 similar to fairscale.nn.ShardedDataParallel (SDP), then I believe it would be much better to keep it as a separate class in PyTorch XLA instead of incorporating it into this PR (FSDP), since it is non-trivial to remove parameter sharding from this PR (the entire codebase assumes parameter sharding). For example, the sharded and full parameters are separate tensors throughout the codebase, and it would be messy to try to merge them or switch between them via a flag, which would make this FSDP class hard to maintain or to extend with new features (such as an option to overlap computation ops with collective ops like all-gather).

So let's keep FSDP as FSDP :)

@ronghanghu (Collaborator, Author) commented:

@hjm-aws Also, regarding your earlier comment on adding process groups to allow mixed data parallelism and model parallelism: this is something we are also planning to look at in Meta AI, so we could introduce a "group" option to this XLA FSDP in a future update after we first test it ourselves for mixed data + model parallelism.

@hjm-aws (Collaborator) commented May 9, 2022

@ronghanghu Hi Ronghang, we need to avoid the all-gather that happens before each forward. Do you have a suggestion for achieving that?

@ronghanghu (Collaborator, Author) commented May 10, 2022

> @ronghanghu Hi Ronghang, we need to avoid the all-gather that happens before each forward. Do you have a suggestion for achieving that?

Based on my understanding of ZeRO-2, an all-gather before the forward pass (or equivalently, after the sharded optimizer update of the previous iteration) is required. One needs to either put it in a sharded-optimizer class that updates the parameters using sharded gradients and then broadcasts the sharded updates via all-gather (as in SDP), or put it into the model before its forward pass (as in FSDP).

In the FSDP case, the model holds the sharded parameters in its idle state and only holds the full parameters during its forward and backward passes, while in the SDP case the model holds the full parameters the whole time and only produces updated sharded parameters during its sharded optimizer update step.
