Implement Fully Sharded Data Parallel (FSDP) in PyTorch XLA #3431
Conversation
Awesome PR. Can't wait to use it!
I successfully ran FSDP to train
@ronghanghu Hi Ronghang, just to confirm I read the code correctly: we don't have prefetch logic here, right? By prefetching I mean pre-gathering the next layer's parameters before the current layer finishes.
@hjm-aws You are right, we don't have prefetch logic here. The all-gather of the next layer happens after the previous layer is completed (instead of happening in parallel). This is consistent with FairScale's FSDP implementation (and also saves more memory compared to parallel prefetch).
As a temporary workaround to #3455 (comment), I added an optional flag for it. This is a makeshift and inefficient workaround to avoid XLA compiler fusion that breaks parameter freeing in nested FSDP (ZeRO-3), and it is useful only when that flag is enabled. The workaround notably increases the execution time, triggers more compilation, and does not cover all the cases, so we still need a permanent solution to #3455 (comment). I'll remove the workaround once a permanent fix is available.

(It seems that the CI build failed due to issues unrelated to this PR; this PR does not touch any C++ codebase.)
(rebased against master branch)
Update: our internal scaling tests show that we can scale up to transformer models with 60 billion parameters on v3-512 using this FSDP PR. Beyond that, we encounter the issue of #3453 and cannot scale to even larger model sizes.
Before there is a long-term fix for #3453, we may want to try some short-term workarounds.
This is exciting! Thanks @ronghanghu!
This is awesome @ronghanghu! Thank you!
@miladm I think this PR should be ready for merge in its current shape once our internal accuracy tests pass. So far the loss curves of these accuracy tests look good, but I'm still waiting to confirm their final results (should know by tomorrow). The current PR commits above are exactly what we used internally at Meta AI in our scaling runs (including our ViT and other experiments).

Regarding the ViT scaling example, I would like to keep it as a separate repository at https://github.com/ronghanghu/vit_10b_fsdp_example for now because 1) it involves many open research questions on how best to train ViTs at the 10~60 billion parameter scale (such as what optimizer to use, what learning rate to use, how to avoid training divergence), for which we don't yet have an optimal recipe like the ImageNet or MNIST examples above, 2) it's more an example of a downstream application than a feature of PyTorch XLA, and 3) it relies on external dependencies. So I think it's better to have a separate repository for scaling examples on top of PyTorch XLA, as opposed to having them as part of PyTorch XLA itself. This is similar to keeping a separate T5X repository on top of JAX.
Thanks @ronghanghu. I agree, it should be on a separate repo. Thanks for sharing the reference to it on this PR.
@ronghanghu Feel free to merge the PR.
LGTM. Merging the PR.
@ronghanghu Hi Ronghang, we (AWS) plan to add ZeRO-2 optimization to your FSDP implementation. ZeRO-2 is just ZeRO-3 minus parameter sharding. We plan to add a flag to the constructor; when the flag is set, parameter sharding will be skipped.
Hi @hjm-aws, you can already do ZeRO-2 by wrapping the entire model with a single (instead of nested) FSDP class. The FairScale library originally introduced an older SDP class for ZeRO-2, but wrapping the whole model in one flat FSDP instance is actually the recommended way to do ZeRO-2 with the newer FSDP class. If one really needs an alternative implementation of ZeRO-2 similar to SDP, it would be cleaner to build it separately (e.g. as a sharded optimizer that all-gathers the updated shards) rather than folding it into this FSDP class behind a flag. So let's keep FSDP to be FSDP :)
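As a rough sketch of the two wrapping styles (the import path and the toy layers below are illustrative, not from this PR; device placement and multiprocessing setup are omitted):

```python
import torch
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP  # assumed import path

def make_layers():
    # stand-in submodules; a real model would use its own blocks
    return [torch.nn.Linear(512, 512) for _ in range(4)]

# ZeRO-2 style: one flat wrap around the whole model, so gradients and optimizer
# states are sharded while the full parameters are materialized on every rank
# during forward/backward.
zero2_model = FSDP(torch.nn.Sequential(*make_layers()))

# ZeRO-3 style: nested wraps, so each layer's full parameters are only gathered
# around its own forward/backward pass and re-sharded (freed) afterwards.
zero3_model = FSDP(torch.nn.Sequential(*[FSDP(layer) for layer in make_layers()]))
```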
@hjm-aws Also, regarding your earlier comment on adding process groups to allow mixed data parallelism and model parallelism -- this is something we are also planning to look into at Meta AI. So we could introduce a "group" option to this XLA FSDP in a future update after we first test it ourselves for mixed data + model parallelism.
@ronghanghu Hi Ronghang, we need to avoid the all-gather that happens before each forward. Do you have a suggestion for achieving that?
Based on my understanding of ZeRO-2, I think an all-gather before the forward pass (or equivalently, after the sharded optimizer update of the previous iteration) is required in ZeRO-2. One needs to either put it in e.g. a sharded optimizer class that updates the parameters using sharded gradients and then broadcasts the sharded updates via all-gather (as in SDP), or put it into the model before its forward pass (as in FSDP). In the FSDP case, the model holds the sharded parameters in its idle state and only holds the full parameters during its forward and backward pass, while in the SDP case the model holds the full parameters the whole time and only generates updated sharded parameters during its sharded optimizer update step.
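For illustration only, a rough sketch of that SDP-style flow (the helper and its flattened, padded shard layout are hypothetical and not part of this PR):

```python
import torch
import torch_xla.core.xla_model as xm

def sdp_style_step(full_params, param_shards, grad_shards, lr=0.01):
    """Update each rank's parameter shard, then all-gather the updated shards so
    every rank holds the full, updated parameters again before the next forward."""
    for full_p, p_shard, g_shard in zip(full_params, param_shards, grad_shards):
        p_shard.add_(g_shard, alpha=-lr)          # sharded (per-rank) SGD update
        gathered = xm.all_gather(p_shard, dim=0)  # all-gather the updated shards
        # shards are assumed flattened and padded to equal size, so trim and reshape
        full_p.data.copy_(gathered[:full_p.numel()].view_as(full_p))
```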
Fully Sharded Data Parallel (FSDP) in PyTorch XLA
This PR implements Fully Sharded Data Parallel (FSDP) in PyTorch XLA for sharding Module parameters across data-parallel workers.
Example usage:
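For instance, a minimal sketch of wrapping a model (the `torch_xla.distributed.fsdp` import path here is an assumption for illustration; the toy model and data are stand-ins):

```python
import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP  # assumed path

device = xm.xla_device()
model = FSDP(torch.nn.Linear(16, 4).to(device))  # wrap the whole model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

inputs = torch.randn(8, 16, device=device)
loss = model(inputs).sum()
loss.backward()
optimizer.step()   # call optimizer.step() directly, not xm.optimizer_step()
xm.mark_step()
```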
It is also possible to shard individual layers separately and have an outer wrapper handle any leftover parameters.
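A sketch of that pattern (again with an assumed import path; device placement and multiprocessing setup are omitted):

```python
import torch
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP  # assumed path

# shard the two large linear layers individually; the outer wrapper handles the
# leftover parameters (here, the small classifier head)
model = torch.nn.Sequential(
    FSDP(torch.nn.Linear(1024, 1024)),
    torch.nn.ReLU(),
    FSDP(torch.nn.Linear(1024, 1024)),
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 10),
)
model = FSDP(model)  # outer wrapper over the whole model
```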
Notes:
- The `XlaFullyShardedDataParallel` class supports both the ZeRO-2 optimizer (sharding gradients and optimizer states) and the ZeRO-3 optimizer (sharding parameters, gradients, and optimizer states) in https://arxiv.org/abs/1910.02054.
- The ZeRO-3 optimizer should be implemented via nested FSDP with `reshard_after_forward=True`. See `test/test_train_mp_mnist_fsdp_with_ckpt.py` and `test/test_train_mp_imagenet_fsdp.py` for an example.
- For large models that cannot fit into a single TPU, sub-modules should be wrapped with inner FSDP while building the entire model. See `FSDPViTModel` for an example.
- A `checkpoint_module` wrapper is provided (based on `torch_xla.utils.checkpoint.checkpoint` from "Add Checkpoint api to use optimization barrier" #3524) to perform gradient checkpointing over a given `nn.Module` instance. See `test/test_train_mp_mnist_fsdp_with_ckpt.py` and `test/test_train_mp_imagenet_fsdp.py` for an example.
- When taking an optimizer step, directly call `optimizer.step` and do not call `xm.optimizer_step`. The latter reduces the gradient across ranks, which is not needed for FSDP (where the parameters are already sharded).
- When saving model and optimizer checkpoints during training, each training process needs to save its own (sharded) model and optimizer state dicts (use `master_only=False` and set different paths for each rank in `xm.save`). When resuming, it needs to load the checkpoint for the corresponding rank.
- To consolidate the sharded model checkpoints into a full model state dict, save `model.get_shard_metadata()` along with `model.state_dict()` as follows and use `consolidate_sharded_model_checkpoints` to stitch the sharded model checkpoints together into a full model state dict. See `test/test_train_mp_mnist_fsdp_with_ckpt.py` for an example.
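A minimal sketch of that save-and-consolidate flow (the import path, keyword arguments, and checkpoint paths below are assumptions for illustration; see the test script for the exact usage):

```python
import torch_xla.core.xla_model as xm

# `model` is an XlaFullyShardedDataParallel instance and `optimizer` its optimizer.
# Each rank saves its own shard of the model (and optimizer) state dict,
# together with the shard metadata needed for later consolidation.
rank, world_size = xm.get_ordinal(), xm.xrt_world_size()
ckpt = {
    'model': model.state_dict(),
    'shard_metadata': model.get_shard_metadata(),
    'optimizer': optimizer.state_dict(),
}
ckpt_path = f'/tmp/fsdp_ckpt_rank-{rank:08d}-of-{world_size:08d}.pth'  # illustrative path
xm.save(ckpt, ckpt_path, master_only=False)

# Later, in a single process, stitch the sharded checkpoints into a full state dict.
from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints  # assumed path
full_state_dict = consolidate_sharded_model_checkpoints(
    ckpt_prefix='/tmp/fsdp_ckpt_', ckpt_suffix='rank-*-of-*.pth')  # assumed arguments
```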
The implementation of this class is largely inspired by and mostly follows the structure of `fairscale.nn.FullyShardedDataParallel` in https://fairscale.readthedocs.io/en/stable/api/nn/fsdp.html. One of the biggest differences from `fairscale.nn.FullyShardedDataParallel` is that in XLA we don't have explicit parameter storage, so here we resort to a different approach to free full parameters for ZeRO-3.

Example training scripts on MNIST and ImageNet
- `test/test_train_mp_mnist_fsdp_with_ckpt.py` (it also tests checkpoint consolidation)
- `test/test_train_mp_imagenet_fsdp.py`
Installation
Update (05/05/2022): This XLA FSDP PR requires the 20220505 nightly XLA wheels and 20220413 libtpu. For TPU VMs, start from the `tpu-vm-pt-1.10` runtime and install the nightly wheels.

Train MNIST on v3-8 TPU
It gets around 98.9 accuracy for 2 epochs:
```
python3 -u ~/xla_fsdp_dev/test/test_train_mp_mnist_fsdp_with_ckpt.py \
  --batch_size 16 --drop_last --num_epochs 2 \
  --use_nested_fsdp --use_gradient_checkpointing
```
This script automatically tests checkpoint consolidation at the end. You can also manually consolidate the sharded checkpoints via `consolidate_sharded_model_checkpoints`.
Train ImageNet with ResNet-50 on v3-8 TPU
It gets around 75.9 accuracy for 100 epochs; download ImageNet-1k to `/datasets/imagenet-1k`:

```
python3 -u ~/xla_fsdp_dev/test/test_train_mp_imagenet_fsdp.py \
  --datadir /datasets/imagenet-1k --drop_last \
  --model resnet50 --test_set_batch_size 64 --eval_interval 10 \
  --lr 0.4 --batch_size 128 --num_warmup_epochs 5 \
  --lr_scheduler_divide_every_n_epochs 30 --lr_scheduler_divisor 10 --num_epochs 100 \
  --use_nested_fsdp
```
You can also add `--use_gradient_checkpointing` (which needs to be used along with `--use_nested_fsdp`) to apply gradient checkpointing on the residual blocks.

Example training scripts on TPU pod (ViT with 10 billion parameters)
To train large models that cannot fit into a single TPU, one should use nested FSDP (wrapping sub-modules with inner FSDP when building the entire model) to implement the ZeRO-3 algorithm.
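For example, a sketch of that construction pattern (import paths and the toy block are illustrative; `checkpoint_module` is the gradient-checkpointing wrapper described in the notes above):

```python
import torch
# import paths are assumptions for illustration
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
from torch_xla.distributed.fsdp import checkpoint_module

def build_block(dim=1024):
    # stand-in for a transformer block
    return torch.nn.Sequential(torch.nn.Linear(dim, dim), torch.nn.GELU(),
                               torch.nn.Linear(dim, dim))

# wrap each block with inner FSDP (plus gradient checkpointing) as the model is
# built, so the full set of parameters never has to be materialized at once
blocks = [FSDP(checkpoint_module(build_block())) for _ in range(12)]
model = FSDP(torch.nn.Sequential(*blocks))  # outer wrapper picks up leftover parameters
```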
Please see https://github.com/ronghanghu/vit_10b_fsdp_example for an example of sharded training of a Vision Transformer (ViT) model using this XLA FSDP PR.
TODOs
- [ ] Speed-memory tradeoff analysis and profiling (partially blocked by "XLA profiler cannot capture TPU device trace when running on a pod" #3446)

Related issues affecting this PR: pytorch/pytorch#74424, #3330, #3392, #3423, #3441, #3446, #3453, #3455, #3502, #3506, #3510, #3545