[WIP] Support FSDP #358
base: main
Conversation
One thing that will break in this implementation is the ability to set a separate WD for LN parameters or the bias, e.g. here: open_clip/src/training/main.py, line 270 at 16e229c. To make sure FSDP retains the original parameter names, you will need to pass an additional field in the FSDP constructor.
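A minimal sketch of that point, assuming the constructor argument in question is FSDP's `use_orig_params` (available in recent PyTorch builds); with it set, `named_parameters()` keeps the original names, so a name-based bias/LN exclusion can still match them. This is illustrative, not the PR's actual code:

```python
# Sketch only: wrap with use_orig_params=True so the name-based
# weight-decay exclusion (bias / LN / logit_scale) keeps working.
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def wrap_with_original_names(model: torch.nn.Module) -> FSDP:
    # use_orig_params=True keeps the original (unflattened) parameters visible,
    # so names like "visual.ln_pre.weight" survive the wrapping.
    return FSDP(model, use_orig_params=True)

def split_param_groups(wrapped: torch.nn.Module, weight_decay: float):
    # Same spirit as the exclusion in training/main.py (approximate, not verbatim).
    exclude = lambda n, p: p.ndim < 2 or "bn" in n or "ln" in n or "bias" in n or "logit_scale" in n
    no_decay = [p for n, p in wrapped.named_parameters() if p.requires_grad and exclude(n, p)]
    decay = [p for n, p in wrapped.named_parameters() if p.requires_grad and not exclude(n, p)]
    return [{"params": no_decay, "weight_decay": 0.0},
            {"params": decay, "weight_decay": weight_decay}]
```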
@orchidmajumder good point, I believe that arg is also required if we want to use torch.compile with FSDP, at least in its current state.
Can you rebase on master?
Thanks @orchidmajumder @rwightman, will look into that! @rom1504 just rebased.
Update: the layer names to FSDP-wrap are no longer hardcoded; they can now be provided on the CLI, with defaults that already work for the models we have.
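For illustration, a CLI-configurable wrapping policy could look roughly like the sketch below; the mapping and function names are made up here, only `transformer_auto_wrap_policy` (PyTorch) and `ResidualAttentionBlock` (open_clip) are real:

```python
# Sketch: turn layer class names given on the command line into an FSDP
# auto-wrap policy. Argument and mapping names are illustrative.
import functools
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from open_clip.transformer import ResidualAttentionBlock

LAYER_CLASSES = {"ResidualAttentionBlock": ResidualAttentionBlock}

def fsdp_wrap(model, layers_to_wrap=("ResidualAttentionBlock",)):
    policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={LAYER_CLASSES[name] for name in layers_to_wrap},
    )
    return FSDP(model, auto_wrap_policy=policy)
```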
Update: following this thread huggingface/accelerate#807, full / partial locking now works. Currently getting some throughput numbers.
Update: I mentioned earlier that training was hanging with large node counts (e.g., 256 on JUWELS Booster). After checking lower numbers of nodes, it seems the start-up phase (before the first "INFO | Train Epoch" line is displayed) is long and proportional to the number of nodes, which is problematic: for 128 nodes it takes 24 min, and for 64 nodes 11 min. Will open an issue in pytorch. So the 256-node run was probably not actually hanging, I just did not run it for long enough, but it would be a lot of time if it took 48 min. (A GPU-usage plot for a 128-node run was attached here.)
Hi @mehdidc, based on your code I tried the ViT-e-14 model; open_clip hangs after the first epoch step with FSDP enabled.
Hey @nkflash, thanks, I actually noticed that as well, even with smaller models; I am on it. EDIT: found a fix, will push soon.
@nkflash pushed, could you please try again? I can confirm that it worked for me.
Thanks @orchidmajumder. The other thing that needs to be changed is open_clip/src/training/main.py, line 271 at 6ed7dd6, since FSDP flattens everything. As for open_clip/src/open_clip/transformer.py, line 366 at 6ed7dd6, the rest is covered by the other clauses, or is this supposed to cover something else? @rwightman @rom1504 @mitchellnw @gabrielilharco what about making the exclusions parametrizable, e.g. with regexps, perhaps to be more explicit about the parameters to decay?
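To make the question concrete, a rough sketch of what regexp-based exclusions could look like; the argument name and the default patterns are only illustrative, not a proposal for actual defaults:

```python
# Illustration of regexp-parametrizable weight-decay exclusions.
import re
import torch

def param_groups(model: torch.nn.Module, weight_decay: float,
                 no_decay_patterns=(r"\.bias$", r"\bln_?\w*\.", r"logit_scale")):
    compiled = [re.compile(p) for p in no_decay_patterns]
    no_decay, decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        (no_decay if any(c.search(name) for c in compiled) else decay).append(param)
    return [{"params": no_decay, "weight_decay": 0.0},
            {"params": decay, "weight_decay": weight_decay}]
```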
Yes, was thinking of that as well, but saw that there is already …
@mehdidc the position and token class embeddings are typically not decayed as well, but it looks like that was never done in OpenCLIP, hrmm. I feel that name-based decay by itself is error prone and won't generalize well to other models (like timm vision towers); the unflattened dim is the strongest signal (but sometimes you need to use names for some layers like embeddings, etc). I feel the best approach would be to build a set of parameter names to decay before the model is FSDP-wrapped.
On the other topic, I feel it's fine to support nightlies only; I'm already exclusively using nightlies to train the convnext models because it's the only way to get decent bfloat16 support for convolutions. I'm going to add torch.compile soon to see how that works; nn.MHA is quite a bit faster on nightlies (it has a fused kernel).
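A sketch of the "build the decay set before wrapping" idea (helper names are illustrative): collect the names to decay from the unwrapped model, where shapes and names are still the original ones, then reuse that set when building optimizer groups after wrapping, assuming the wrapped model still exposes the original names (e.g. via `use_orig_params=True`):

```python
# Sketch: decide decay membership before FSDP wrapping, using the unflattened
# dim as the main signal plus a few name-based exceptions, then reuse the set
# after wrapping. Helper names are illustrative.
import torch

def collect_decay_names(model: torch.nn.Module) -> set:
    no_decay_keywords = ("positional_embedding", "class_embedding", "token_embedding", "logit_scale")
    return {
        name for name, p in model.named_parameters()
        if p.ndim >= 2 and not any(k in name for k in no_decay_keywords)
    }

def param_groups_from_names(wrapped: torch.nn.Module, decay_names: set, weight_decay: float):
    decay, no_decay = [], []
    for name, p in wrapped.named_parameters():
        if not p.requires_grad:
            continue
        (decay if name in decay_names else no_decay).append(p)
    return [{"params": no_decay, "weight_decay": 0.0},
            {"params": decay, "weight_decay": weight_decay}]
```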
@rwightman Thanks for the suggestion, I moved the code a bit earlier; it is fixed now.
Update: @rwightman @rom1504 @mitchellnw @gabrielilharco @JeniaJitsev just for info, regarding the start-up phase I mentioned earlier (#358 (comment)), I found out that it is proportional not only to the number of nodes but also to model size, but I found a fix. Read below if you want more info. For example, with 256 nodes on JUWELS Booster it took 13 min for ViT-B/32, 16 min for ViT-L/14, and 28 min for ViT-g/14; that is a lot of wasted time.
I have also observed the delay with FSDP on AWS clusters and actually thought FSDP hangs over a certain number of nodes, so I didn't pursue it further. Thanks for the amazing deep-dive @mehdidc.
I checked out the head code; it works well now.
Update: the first fully trained model with FSDP is finished. I started with a ViT-B/32 on LAION-400M, 32 epochs (96 GPUs, local batch size of 896, global batch size of 86016, lr of 0.001); zero-shot accuracy on ImageNet is 63.6%, with ~90K samples/s throughput. Training was done using pytorch-nightly.
It's similar to what we get in https://arxiv.org/pdf/2212.07143.pdf (Table 13).
…eter names to avoid erroneous parameter decay, and decay params by constructing a set of parameter names to decay before FSDP wrapping (thanks to @rwightman)
…ict and we shard the optim state dict after loading
…y from pytorch nightly - fix grad checkpointing offloading to be compatible with pytorch nightly - use sync_module_states
- Only import FSDP modules if possible to avoid import error
…f adding args.distributed_engine
This PR adds FSDP (https://pytorch.org/docs/stable/fsdp.html) support for training large models that cannot fit in memory.
The code already works, but still needs to be improved, so this is still a draft.
Some scaling plots with samples per second per GPU, done on JUWELS Booster (plots attached for G-14 and for X-14, which has a 15B visual and a 5B text encoder).
I also tried G-14 as visual encoder together with a pre-trained T5-XXL as text encoder.
Putting again some remarks and possible improvements, discussed earlier on Discord:
- … `all_gather`, which is expected from FSDP.
- If `encode_image` or `encode_text` or `logit_scale` were accessed without going through the forward function (this happens when clipping the logit scale, or at evaluation), an exception was raised (see [FSDP] caffe2 error in forward method when using fsdp, pytorch/pytorch#82461, for reference). The workaround I found is to modify the forward function so that it is possible to encode both text and image (as currently done), or text only, or image only, or use it for clipping the logit scale. It would be better if we find a cleaner solution. The solution proposed in the pytorch issue above is to wrap the modules (here, the text and image encoders) with FSDP, but then we need to change some internals, as part of the text encoder cannot be wrapped since it is an `nn.Parameter` and FSDP needs an `nn.Module`. We could use `CustomTextCLIP` to wrap the text encoder in its entirety, as proposed by @rwightman, but then we need to deal with `logit_scale`.
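A rough sketch of the forward-dispatch workaround described in the last point (class and argument names are illustrative, not the PR's actual interface):

```python
# Sketch: route image encoding, text encoding, and logit_scale clipping
# through forward() so FSDP always sees parameter access inside forward.
import torch
import torch.nn.functional as F
from torch import nn

class DispatchingCLIP(nn.Module):
    def __init__(self, visual: nn.Module, text: nn.Module):
        super().__init__()
        self.visual = visual
        self.text = text
        self.logit_scale = nn.Parameter(torch.ones([]) * 2.659)  # ~ln(1/0.07)

    def forward(self, image=None, text=None, clamp_logit_scale=None):
        image_features = F.normalize(self.visual(image), dim=-1) if image is not None else None
        text_features = F.normalize(self.text(text), dim=-1) if text is not None else None
        if clamp_logit_scale is not None:
            with torch.no_grad():  # clip the scale inside forward, under FSDP's control
                self.logit_scale.clamp_(0, clamp_logit_scale)
        return image_features, text_features, self.logit_scale.exp()
```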