[RFC] FSDP via SPMD #6379
Comments
Awesome, thanks @alanwaketan! One thing I did not find in the RFC is gradient clipping, which in FSDPv1 requires collecting the global gradient norm before clipping. I'm wondering how it will be handled with GSPMD.
Optimizer states will be sharded automatically; we just don't need to specify their sharding. Please refer to https://pytorch.org/blog/high-performance-llama-2/. Gradient clipping is still under debate, and I'm not quite sure whether we need it or not. Can you remind me what the benefits of this function are? I don't think we ever need it on TPU.
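For context on why clipping needs a *global* norm under sharding: each device holds only a slice of every gradient, so the clipping norm has to be assembled from per-shard partial sums. Below is a minimal pure-Python sketch, assuming each inner list stands for one device's gradient shard and that the cross-device sum (an all-reduce in practice) is simulated by a plain `sum`. It mirrors the scaling rule of `torch.nn.utils.clip_grad_norm_`, but it is an illustration, not PyTorch/XLA code.

```python
import math
from typing import List, Tuple

Shards = List[List[float]]  # one inner list per device's gradient shard (illustrative)

def global_grad_norm(shard_grads: Shards) -> float:
    # Each device can only compute the sum of squares of its own shard...
    local_sq_sums = [sum(g * g for g in shard) for shard in shard_grads]
    # ...so the partial sums must be added across devices (an all-reduce,
    # simulated here by a plain sum) before taking the square root.
    return math.sqrt(sum(local_sq_sums))

def clip_by_global_norm(shard_grads: Shards, max_norm: float,
                        eps: float = 1e-6) -> Tuple[Shards, float]:
    total = global_grad_norm(shard_grads)
    # Same rule as torch.nn.utils.clip_grad_norm_: scale down only, never up.
    coef = min(1.0, max_norm / (total + eps))
    return [[g * coef for g in shard] for shard in shard_grads], total

clipped, norm = clip_by_global_norm([[3.0, 0.0], [4.0]], max_norm=1.0)
print(norm)     # 5.0
print(clipped)  # roughly [[0.6, 0.0], [0.8]]
```

Clipping each shard by its own local norm instead would turn `[3.0, 0.0]` and `[4.0]` into `[1.0, 0.0]` and `[1.0]`, changing the direction of the overall gradient. That is why the global reduction matters.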
We have found that in order to effectively utilize SPMD-FSDP, it is necessary to configure
Can gradient checkpointing solve your issue?
No. The all-gather is for weights, not activations, so it's unrelated to gradient checkpointing. The issue becomes more pronounced with larger models like llama2-70B; the increase in GPU memory is barely noticeable with smaller models. Additionally, in the Python implementation of FSDP, the all-gather operations are generated separately during the forward and backward passes and are not merged.
So you are trying to prevent the consolidation of the fwd and bwd all-gathers during the ckpted backward? Since there is no way the original fwd's all-gather would be consolidated into the ckpted backward.
Not sure if you're following the per-param FSDP rewrite; see the PRs labeled [FSDP2]. There are some UX changes if you want to try to stay consistent.
@wconstab It would be great to stay consistent. On that note, are there any design docs?
@awgu, can you link it?
A majority of the info is captured here: pytorch/pytorch#114299
Thanks, will take a look!
FSDP via SPMD (FSDP v2)
Introduction
FSDP (Fully Sharded Data Parallel) is a well-known distributed training algorithm in the PyTorch world. SPMD is PyTorch/XLA's API that lets users annotate a single-device PyTorch model and then lets XLA's GSPMD feature turn it into a distributed model. This design doc focuses on how to use SPMD to express FSDP and how to make this new implementation performant and easy to use. Since PyTorch/XLA already has a native implementation of FSDP, this new implementation is also referred to as FSDPv2.
Background
Many past SPMD training experiments have demonstrated that FSDP, i.e., 1D sharding, performs better than 2D sharding as long as the model fits into the training fleet.
To express FSDP using the vanilla SPMD API, one currently needs to accomplish the following 5 steps. Examples are taken from our HF Llama 2 fork.
1. Define mesh
2. Shard data loader
3. Shard weights
4. Shard activations
5. Apply backward optimization barrier
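Steps 1 through 4 boil down to choosing a mesh and a partition spec per tensor. Here is a rough, plain-Python illustration that computes those values so it runs anywhere. It assumes a single mesh axis named `"fsdp"` (a hypothetical name). In the real API, the tuples would presumably be passed as the partition spec to torch_xla's SPMD module (`xs.Mesh`, `xs.mark_sharding`). `fsdp_mesh`, `batch_spec`, and `weight_spec` are illustrative helpers, not library functions. Step 5 is omitted because it needs a backward hook on a real module.

```python
from typing import List, Optional, Sequence, Tuple

FSDP_AXIS = "fsdp"  # hypothetical mesh axis name
Spec = Tuple[Optional[str], ...]

def fsdp_mesh(num_devices: int) -> Tuple[Tuple[int, ...], Tuple[str, ...]]:
    """Step 1: a 1D mesh that places every device on the FSDP axis."""
    return (num_devices,), (FSDP_AXIS,)

def batch_spec(ndim: int) -> Spec:
    """Steps 2 and 4: split dim 0 (the batch dim) of inputs and activations,
    e.g. a decoder layer's hidden_states; keep the other dims whole."""
    return (FSDP_AXIS,) + (None,) * (ndim - 1)

def weight_spec(shape: Sequence[int], shard_largest: bool = False) -> Spec:
    """Step 3: split one dim of each weight across the FSDP axis."""
    if len(shape) == 0:
        return ()  # scalars stay replicated
    spec: List[Optional[str]] = [None] * len(shape)
    index = list(shape).index(max(shape)) if shard_largest else 0
    spec[index] = FSDP_AXIS
    return tuple(spec)

# Shapes loosely modeled on a Llama 2 MLP weight; purely illustrative.
print(fsdp_mesh(8))                                    # ((8,), ('fsdp',))
print(batch_spec(3))                                   # ('fsdp', None, None)
print(weight_spec((11008, 4096)))                      # ('fsdp', None)
print(weight_spec((4096, 11008), shard_largest=True))  # (None, 'fsdp')
```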
Even though FSDP requires far fewer sharding annotations than 2D sharding, which for LLMs needs a lot more sharding on activations plus additional sharding on the attention layers, it's still complex. Here is a recap of how our native FSDP implementation, which we will refer to as FSDPv1 in this design, is generally used:
It's far less boilerplate code to write, and thus a much better user experience! Here comes the problem statement: could we recreate the same user experience while keeping the same performance as the vanilla SPMD FSDP annotations?
Goals
Non-goals
Feature Requirements
Before discussing the design candidates, let's detail the features and characteristics of this new system. This will help us navigate the different design candidates.
P0: Shard on weights
This is the basic concept of FSDP where weights are sharded and distributed among the training fleet.
P0: Shard on activations
This is not needed in FSDPv1. However, if we omit it from the vanilla SPMD example in the Background section, performance is much worse. The following benchmarks were taken on a v4-8 with Llama 2 2B at 1K seq_len.
Shard activations
Don’t shard activations
It turns out that in the second case, the compiler decides to do some wild all-to-all and all-reduce in the attention layer.
Therefore, this feature is required to steer the compiler toward the FSDP algorithm. Fortunately, for LLMs we only need to shard either the input or the output hidden_states of each decoder layer, not every activation.
P0: Backward optimization barrier
This is needed to prevent gigantic fusions when syncing the gradients. The only remaining question is whether it's compatible with gradient checkpointing, since both modify the backward pass in some fashion. Theoretically, they should be compatible, and the order of application shouldn't matter.
P0: Manual wrapping
Most of the features here should be packaged together so that they can be applied separately to the root module and to child modules. Let's take FSDPv1 as an example. Typically, the wrapper is applied to two types of modules:
1. The root module.
2. Important backbone child modules, e.g., the decoder layers in LLMs.
Even though FSDPv1 by default shards all the parameters in the wrapped module, including those of its children, the rebuilding, memory-freeing, and gradient-synchronizing logic applies only at the granularity of wrapped modules. If only the root module is wrapped, all parameters are rebuilt in full during the outermost forward, so there is no memory saving. If every child module is wrapped, the overhead is simply too high. That's why usually only the two types of modules above are wrapped.
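To make the wrapping trade-off concrete, here is a back-of-the-envelope model of per-device parameter memory during the forward pass. It is a simplification under the assumptions listed in the docstring. The model sizes are made up, not measured, and `peak_param_memory` is a hypothetical helper.

```python
def peak_param_memory(num_devices: int, num_layers: int, layer_gb: float,
                      root_own_gb: float, wrap_layers: bool) -> float:
    """Rough per-device peak parameter memory (GB) during the forward pass.

    Assumptions (illustrative only): every device keeps its 1/num_devices
    shard resident, the root's own parameters (e.g. embeddings) stay gathered
    for the whole outer forward, and at most one wrapped decoder layer is
    gathered at a time. Gradients, optimizer states, activations and
    prefetching are ignored.
    """
    total = num_layers * layer_gb + root_own_gb
    resident_shard = total / num_devices
    if not wrap_layers:
        # Root-only wrapping: the whole model is rebuilt for the outermost forward.
        return resident_shard + total
    # Root + decoder layers: only one decoder layer is rebuilt at a time.
    return resident_shard + root_own_gb + layer_gb

# Hypothetical model: 64 decoder layers of 2 GB each plus 4 GB of root-level params.
print(peak_param_memory(32, 64, 2.0, 4.0, wrap_layers=False))  # 136.125
print(peak_param_memory(32, 64, 2.0, 4.0, wrap_layers=True))   # 10.125
```

Wrapping every linear layer would shrink the gathered unit even further, but, as noted above, the per-module overhead then dominates.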
P1: MultiSlice support
The implementation should be flexible enough to support 1) data parallelism over MultiSlice and 2) FSDP over MultiSlice.
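One way to picture the two MultiSlice options is as two different mesh layouts. This is a plain-Python illustration only. The axis names `"data"` and `"fsdp"` and the `MeshPlan`/`multislice_plan` helpers are hypothetical, not part of PyTorch/XLA.

```python
import math
from dataclasses import dataclass
from typing import Tuple

@dataclass(frozen=True)
class MeshPlan:
    shape: Tuple[int, ...]
    axis_names: Tuple[str, ...]
    weight_axes: Tuple[str, ...]  # mesh axes that split each weight
    batch_axes: Tuple[str, ...]   # mesh axes that split the batch dim

    def ways(self, axes: Tuple[str, ...]) -> int:
        """Number of pieces a dim is cut into when sharded over `axes`."""
        return math.prod(self.shape[self.axis_names.index(a)] for a in axes)

def multislice_plan(num_slices: int, devices_per_slice: int,
                    fsdp_across_slices: bool) -> MeshPlan:
    if fsdp_across_slices:
        # Option 2: one flat FSDP axis spanning every device in every slice.
        n = num_slices * devices_per_slice
        return MeshPlan((n,), ("fsdp",), ("fsdp",), ("fsdp",))
    # Option 1: FSDP inside each slice, plain data parallelism across slices.
    # Weights are replicated across slices; the batch is split over both axes.
    return MeshPlan((num_slices, devices_per_slice), ("data", "fsdp"),
                    ("fsdp",), ("data", "fsdp"))

dp = multislice_plan(2, 4, fsdp_across_slices=False)
fs = multislice_plan(2, 4, fsdp_across_slices=True)
print(dp.ways(dp.weight_axes), dp.ways(dp.batch_axes))  # 4 8
print(fs.ways(fs.weight_axes), fs.ways(fs.batch_axes))  # 8 8
```

In option 1, each weight is split only within a slice, so its all-gathers stay inside the slice and only the gradient reduction crosses slices. Option 2 splits weights further at the cost of all-gathers that span slices.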
P1: Defer Parameter initialization
This is needed when the total model size exceeds the host memory. On TPU v5e, host memory is extremely limited, so this feature becomes a must. Essentially, we need to initialize the model layer by layer and transfer each layer to the device immediately after it is initialized.
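The layer-by-layer strategy can be sketched as a loop that never holds more than one freshly initialized layer on the host. `init_layer` and `to_device` are hypothetical callbacks. In practice, they might build the layer on CPU and move it to the XLA device. Memory is tracked in arbitrary units.

```python
from typing import Any, Callable, List, Sequence, Tuple

def init_layer_by_layer(layer_sizes: Sequence[int],
                        init_layer: Callable[[int], Any],
                        to_device: Callable[[Any], Any]) -> Tuple[List[Any], int]:
    """Initialize one layer at a time on the host and ship it to the device
    immediately, tracking peak host memory (arbitrary units)."""
    host_in_use = 0
    peak_host = 0
    device_layers = []
    for i, size in enumerate(layer_sizes):
        host_layer = init_layer(i)        # materialize layer i on the host
        host_in_use += size
        peak_host = max(peak_host, host_in_use)
        device_layers.append(to_device(host_layer))
        host_in_use -= size               # assume the host copy is freed after transfer
    return device_layers, peak_host

sizes = [4, 4, 4, 4]
layers, peak = init_layer_by_layer(sizes, init_layer=lambda i: f"layer{i}",
                                   to_device=lambda layer: f"{layer}@device")
print(layers)            # ['layer0@device', 'layer1@device', 'layer2@device', 'layer3@device']
print(peak, sum(sizes))  # 4 16
```

Initializing the whole model first would peak at `sum(sizes)` on the host. The loop above peaks at the size of the largest single layer.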
P1: Auto wrapping
This refers to the ability to apply the same set of rules (e.g., sharding, optimization barrier, etc.) automatically from the root module to its child modules.
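Auto wrapping amounts to walking the module tree once and applying a predicate to each child. This sketch uses a tiny stand-in `Module` class rather than `torch.nn.Module`, so it runs without PyTorch. The type-name predicate is similar in spirit to the `transformer_auto_wrap_policy` used with PyTorch FSDP, but `auto_wrap` itself is hypothetical.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List

@dataclass
class Module:
    """Tiny stand-in for nn.Module: a type name plus named children."""
    type_name: str
    children: Dict[str, "Module"] = field(default_factory=dict)

def auto_wrap(root: Module, should_wrap: Callable[[Module], bool]) -> List[str]:
    """Return the dotted paths of the modules that would get the wrapper rules."""
    wrapped = ["<root>"]  # the root module is always wrapped

    def visit(module: Module, path: str) -> None:
        for name, child in module.children.items():
            child_path = f"{path}.{name}" if path else name
            if should_wrap(child):
                wrapped.append(child_path)
            visit(child, child_path)

    visit(root, "")
    return wrapped

def decoder_layer() -> Module:
    return Module("LlamaDecoderLayer",
                  {"self_attn": Module("LlamaAttention"), "mlp": Module("LlamaMLP")})

model = Module("LlamaForCausalLM", {
    "model": Module("LlamaModel", {
        "embed_tokens": Module("Embedding"),
        "layers": Module("ModuleList", {"0": decoder_layer(), "1": decoder_layer()}),
    }),
    "lm_head": Module("Linear"),
})
print(auto_wrap(model, lambda m: m.type_name == "LlamaDecoderLayer"))
# ['<root>', 'model.layers.0', 'model.layers.1']
```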
P1: Distributed checkpointing support
There are two use cases here: one is exception handling during the training job, and the other is consolidating checkpoints for future inference. This requirement may have different implications for different design candidates. For example, an nn.Module wrapper approach will introduce additional naming prefixes in the state_dicts.
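To illustrate the naming-prefix issue, suppose the wrapper stores the original module under an attribute. Every parameter key then gains that attribute name as a prefix, once per nesting level. The sketch below strips such prefixes so that a consolidated checkpoint matches the unwrapped model. The prefix `_spmd_wrapped_module.` is hypothetical.

```python
from typing import Dict, TypeVar

V = TypeVar("V")

# Hypothetical attribute name under which a wrapper would store the original module.
WRAPPED_PREFIX = "_spmd_wrapped_module."

def strip_wrapper_prefix(state_dict: Dict[str, V],
                         prefix: str = WRAPPED_PREFIX) -> Dict[str, V]:
    """Drop every occurrence of the wrapper prefix so keys match the unwrapped model."""
    cleaned: Dict[str, V] = {}
    for key, value in state_dict.items():
        new_key = key.replace(prefix, "")
        if new_key in cleaned:
            raise KeyError(f"two keys collapse to {new_key!r} after stripping")
        cleaned[new_key] = value
    return cleaned

wrapped = {
    "_spmd_wrapped_module.model.embed_tokens.weight": "E",
    "_spmd_wrapped_module.model.layers.0._spmd_wrapped_module.mlp.up_proj.weight": "W",
}
print(strip_wrapper_prefix(wrapped))
# {'model.embed_tokens.weight': 'E', 'model.layers.0.mlp.up_proj.weight': 'W'}
```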
P1: HuggingFace and Lightning integrations
FSDPv2 should be designed to easily replace the current FSDPv1 integrations in HuggingFace and Lightning, and thus become the default distributed algorithm for PyTorch/XLA in those two high-level frameworks.
P2: Mixed precision support
FSDPv1 offers manual mixed precision support, where the weights are always kept in FP32 but compute can be performed in BF16. Mixed precision support is definitely needed, but whether it should be provided via torch.amp or via FSDPv2 itself is under discussion.
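A few lines of arithmetic show why keeping FP32 weights while computing in BF16 is useful. BF16 keeps only 8 significant bits, so a small update to a weight near 1.0 is rounded away. The sketch emulates BF16 with `struct` by rounding FP32 bit patterns. It is a numerical illustration, not FSDPv1 code.

```python
import struct

def to_fp32(x: float) -> float:
    """Round a Python float to IEEE single precision."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def to_bf16(x: float) -> float:
    """Round to bfloat16 by keeping the top 16 bits of the FP32 pattern
    (round-to-nearest-even; NaN handling omitted for brevity)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def apply_updates(w0: float, update: float, steps: int, fp32_master: bool) -> float:
    w = to_fp32(w0)
    for _ in range(steps):
        # With FP32 master weights, compute would run on to_bf16(w),
        # but the update is accumulated into the FP32 copy.
        w = to_fp32(w + update) if fp32_master else to_bf16(w + update)
    return w

print(apply_updates(1.0, 1e-4, 100, fp32_master=False))  # 1.0: every update is rounded away
print(apply_updates(1.0, 1e-4, 100, fp32_master=True))   # close to 1.01
```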
P2: Gradient Averaging
FSDPv1 offers a convenient way of averaging the gradients by world_size to avoid overflows during all_reduce. It is unclear whether this is necessary for this design.
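The overflow concern can be reproduced with a simulated reduction. This sketch uses IEEE half precision (FP16, via `struct`'s `"e"` format) only because its small range makes the overflow easy to trigger. BF16 shares FP32's exponent range. The sequential `all_reduce_sum` is a simplification of a real all-reduce.

```python
import math
import struct
from typing import List

def to_fp16(x: float) -> float:
    """Round to IEEE half precision; values too large to represent become +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)

def all_reduce_sum(values: List[float]) -> float:
    """Simulated all-reduce: a sequential sum with a half-precision accumulator."""
    acc = 0.0
    for v in values:
        acc = to_fp16(acc + v)
    return acc

def average_after_reduce(grads: List[float]) -> float:
    return to_fp16(all_reduce_sum(grads) / len(grads))

def average_before_reduce(grads: List[float]) -> float:
    n = len(grads)
    return all_reduce_sum([to_fp16(g / n) for g in grads])

grads = [8192.0] * 8  # the same gradient value on each of 8 devices
print(average_after_reduce(grads))   # inf: the running sum reaches 65536 > 65504
print(average_before_reduce(grads))  # 8192.0
```

Dividing by world_size before the reduction keeps every partial sum within the range of the final average, which is the point of the FSDPv1 option.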
N.A.: Shard on attentions
This is needed in the case of 2D sharding, as we need to pick the num_attention_heads dim to shard on the model axis. Since we only shard on the bs dim, theoretically speaking we shouldn't need it. Experiments also validate the theory.
N.A.: Replace nn.Linear
This is needed for 2D sharding because we don't want PyTorch to merge two dims of a tensor that are both sharded during a matmul operation. However, in FSDP at most one dim of each tensor is sharded, so this is not needed. Here is the xprof that drops XLAPatchedLinear; no performance degradation is observed.
N.A.: Shard optimizer states
This proved unnecessary during the 2D sharding exercise.
Design
This section discusses two approaches. Each will have a PoC implementation that includes all the P0 features, to demonstrate its feasibility as well as its pros and cons. One of them will then be selected as the final design, and more P1 and P2 features will be added on top of it.
As an nn.Module: SpmdFullyShardedDataParallel
This is a traditional approach like FSDPv1. Here are the major components:
`__init__`
It will take care of the following P0s:
Below is the pseudo code:
forward
This is the most debatable part. The forward function is used to shard activations, specifically the output of the original module. As discussed in the section above, this is required to maintain high performance. However, the output of the forward function can be anything, which makes it really hard to shard generically.
Here is the proposed solution. Conventionally, the output usually will be:
For the 1st case, we can safely shard it. For the 2nd case, we can shard the 0th element and warn the user to provide an output sharding function if that element is not the intended one. For all other cases, we will raise a runtime error asking users to provide an output sharding function. That's what `shard_output` in the above section represents.
Code boilerplate
Some boilerplate code copied from FSDPv1.
Manual wrapping
This follows the same logic as FSDPv1: only the root module and important backbone child modules, like the decoder layers in LLMs, should be wrapped. This is very important because we don't want to register the backward optimization barrier on every linear layer; otherwise, the compiler would lose a ton of optimization opportunities.
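The output-sharding rule proposed under forward above can be sketched as a small dispatch function. This is a hypothetical, framework-free version. `shard_fn` stands in for something like `xs.mark_sharding(t, mesh, spec)`, and `is_tensor` for `isinstance(x, torch.Tensor)`. Treating the 2nd case as "a tuple or list whose 0th element is a tensor" is our reading of the text, not a confirmed detail of the PoC.

```python
import warnings
from typing import Any, Callable, Optional

def shard_output(output: Any,
                 shard_fn: Callable[[Any], None],
                 is_tensor: Callable[[Any], bool],
                 output_sharding_fn: Optional[Callable[[Any], None]] = None) -> Any:
    """Hypothetical default rule for annotating a wrapped module's output in place."""
    if output_sharding_fn is not None:
        output_sharding_fn(output)  # a user-provided rule always wins
        return output
    if is_tensor(output):
        shard_fn(output)            # case 1: a single tensor
        return output
    if isinstance(output, (tuple, list)) and output and is_tensor(output[0]):
        warnings.warn("Sharding only the 0th element of the output; "
                      "pass output_sharding_fn if that is not intended.")
        shard_fn(output[0])         # case 2: e.g. (hidden_states, ...)
        return output
    raise RuntimeError("Unsupported output type; please provide output_sharding_fn.")
```

Keeping the user override first means that an explicit `output_sharding_fn` can handle any return type, including the ones that would otherwise hit the runtime error.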
(WIP) As a function: spmd_fully_sharded_data_parallel
This will be very similar to distribute_module and, in fact, can probably be a high-level wrapper on top of it. Basically, it will take care of everything within a single function call.
Test Plan
(WIP) PoCs Evaluation
SpmdFullyShardedDataParallel
Performance
Project Management
Milestone 1:
Milestone 2:
TBD, implement the rest of the feature set.
Open Questions
cc @JackCaoG @yeounoh @jonb377 @wconstab @wanchaol