
[WIP] Support FSDP #358

Draft · mehdidc wants to merge 37 commits into main from fsdp
Conversation

@mehdidc (Contributor) commented Jan 17, 2023

This PR adds FSDP (https://pytorch.org/docs/stable/fsdp.html) support for training large models that cannot fit in memory.

The code already works, but it still needs to be improved, so this is still a draft.

Some scaling plots (samples per second per GPU), measured on JUWELS Booster.

[scaling plot: G-14]

[scaling plot: X-14 (15B visual, 5B text)]

I also tried G-14 as the visual encoder together with a pre-trained T5-XXL as the text encoder.

Repeating some remarks and possible improvements discussed earlier on Discord:

  • I see some hanging issues starting from a large number of nodes (256 nodes / 1024 GPUs on JUWELS Booster): not a single iteration completes. I don't see anything special in the NCCL debugging info, except that it does a lot of all_gather, which is expected from FSDP
  • CPU offloading and gradient checkpointing are supported
  • each time encode_image, encode_text, or logit_scale was accessed without going through the forward function (which happens when clipping the logit scale, or at evaluation), an exception was raised (see [FSDP] caffe2 error in forward method when using fsdp pytorch/pytorch#82461 for reference). The workaround I found is to modify the forward function so that it can encode both text and image (as currently done), text only, or image only, or be used for clipping the logit scale. It would be better if we find a cleaner solution. The solution proposed in the pytorch issue above is to wrap the modules (here the text and image encoders) using FSDP, but then we need to change some internals, as part of the text encoder cannot be wrapped: it is an nn.Parameter, and FSDP needs an nn.Module. We could use CustomTextCLIP to wrap the text encoder in its entirety, as proposed by @rwightman, but then we need to deal with logit_scale.
  • The list of layers to FSDP-wrap is important, as it affects the peak memory (documented here: https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html#id2), and the layer names are model dependent: e.g. in T5 I am FSDP-wrapping T5Block, and in the CLIP class I am wrapping ResidualAttentionBlock. So we need a way to parametrize this; for the moment the names are hardcoded (see the sketch just below this list). If we just use the default auto-wrap policy of FSDP, we get OOM.
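A minimal sketch of how such a wrap policy can be parametrized by layer class, using PyTorch's transformer_auto_wrap_policy helper; the open_clip import and model construction are illustrative, not this PR's exact code:

```python
import functools

import open_clip
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# ResidualAttentionBlock is the transformer block used in open_clip's towers;
# for a T5 text tower one would add T5Block to the set as well.
from open_clip.transformer import ResidualAttentionBlock

# assumes torch.distributed is already initialized (e.g. via torchrun)
model, _, preprocess = open_clip.create_model_and_transforms("ViT-B-32")

# Each ResidualAttentionBlock becomes its own FSDP unit, which bounds peak
# memory to roughly one unsharded block at a time instead of a whole tower.
wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={ResidualAttentionBlock},
)
model = FSDP(model, auto_wrap_policy=wrap_policy)
```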

@mehdidc marked this pull request as draft January 17, 2023 15:27
@orchidmajumder commented:
One thing that will break in this implementation is the ability to set a separate weight decay for LN parameters or biases, e.g. here:

```python
exclude = lambda n, p: p.ndim < 2 or "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n
```

Because FSDP wraps the VisionTransformer or TextTransformer into a single FSDP block, these parameter names will not be retained and the filter will fail silently. You can print the parameter names after the model is wrapped in FSDP to verify this.

To make sure FSDP retains the original names, you will need to pass an additional argument to the FSDP constructor, use_orig_params=True. See my discussion on the PyTorch forum here: https://discuss.pytorch.org/t/setting-different-weight-decay-values-for-parameters-within-one-fsdp-unit/169862/2. This feature needs PyTorch nightly.
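A minimal sketch of that fix, reusing the wrap_policy from the earlier sketch (requires a nightly build at the time of writing):

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# use_orig_params=True keeps the original (unflattened) parameters visible,
# so their names and per-parameter shapes survive wrapping.
model = FSDP(model, auto_wrap_policy=wrap_policy, use_orig_params=True)

# Without use_orig_params this loop only sees 1-D `flat_param` entries;
# with it, the original names and shapes are retained.
for name, param in model.named_parameters():
    print(name, tuple(param.shape))
```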

@rwightman (Collaborator) commented Jan 26, 2023

@orchidmajumder good point, I believe that arg is also required if we want to use torch.compile with FSDP, at least in its current state

@rom1504 (Collaborator) commented Jan 30, 2023

Can you rebase on master?

@mehdidc (Contributor, Author) commented Jan 31, 2023

Thanks @orchidmajumder @rwightman, will look into that! @rom1504 just rebased.

@mehdidc (Contributor, Author) commented Feb 2, 2023

Update: the layer names to FSDP-wrap are no longer hardcoded; they can now be provided on the CLI, with defaults that already work with the models we have.

@mehdidc (Contributor, Author) commented Feb 4, 2023

Update: following this thread (huggingface/accelerate#807), full/partial locking now works; a sketch of the approach is below. Currently getting some throughput numbers with mt5-xxl-ViT-G-14.
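A minimal sketch of the approach from that thread, with illustrative names (`text_encoder.blocks` is not this PR's exact API): parameters are frozen before FSDP wrapping, so each FSDP unit ends up either fully trainable or fully frozen.

```python
# Freeze everything except the last 5 transformer blocks of the text tower.
blocks = list(text_encoder.blocks)  # illustrative attribute name
for block in blocks[:-5]:
    for p in block.parameters():
        p.requires_grad = False

# Wrap only after freezing: FSDP groups parameters into flat buffers at wrap
# time, and mixing frozen and trainable parameters inside one unit is what
# causes trouble (see the accelerate thread above).
model = FSDP(model, auto_wrap_policy=wrap_policy)
```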

@mehdidc (Contributor, Author) commented Feb 9, 2023

Update: I mentioned earlier that training was hanging at large node counts (e.g., 256 on JUWELS Booster). After checking lower node counts, it seems that the start-up phase (before the first "INFO | Train Epoch" line is displayed) is long and proportional to the number of nodes, which is problematic: for 128 nodes the start-up phase takes 24 minutes, and for 64 nodes it takes 11 minutes. I will open an issue on PyTorch. So the 256-node run was probably not actually hanging, I just did not run it for long enough; but that is a lot of time if it really takes ~48 minutes.

[GPU usage plots: starting_up, starting_up_2]

This is the GPU usage for a 128-node run of mt5-xxl-ViT-G-14: low GPU usage for the first 24 minutes, then usage rises above 99%, which coincides with the first "INFO | Train Epoch" message in the logs.

@nkflash commented Feb 16, 2023

Hi @mehdidc, based on your code I tried the ViT-e-14 model, and OpenCLIP hangs after the first epoch step with FSDP enabled. Do you see the same issue?

@mehdidc (Contributor, Author) commented Feb 17, 2023

Hey @nkflash, thanks, I actually noticed that as well, even with smaller models. I am on it.

EDIT: found a fix, will push soon

@mehdidc (Contributor, Author) commented Feb 18, 2023

@nkflash pushed, could you please try again? I can confirm that it worked for me

@mehdidc (Contributor, Author) commented Feb 18, 2023

Thanks @orchidmajumder, use_orig_params is working as expected, so with PyTorch nightly we can already use it. If we also want to support the current stable PyTorch (1.13), wrapping the layer norms in their own FSDP units using the option I added (--fsdp-layers-to-wrap) would also work, but it does not handle other cases, e.g. biases from MLP layers, which we would also need to wrap separately. So I am not sure we can support the current stable release (1.13) without more complications in the code. For now, I think we just need to document this for the user (unless we find a better solution): the closest we can get to the "correct" behavior on 1.13 is to FSDP-wrap the layer norms (a sketch follows below), but in that case biases from MLPs will still be decayed; otherwise one needs PyTorch nightly or the next stable version.
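A minimal sketch of that 1.13 workaround; transformer_auto_wrap_policy wraps any module whose class is in the given set, so nn.LayerNorm can simply be added alongside the block class (imports as in the earlier sketches):

```python
import functools

import torch.nn as nn
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from open_clip.transformer import ResidualAttentionBlock

# Each LayerNorm gets its own FSDP unit, so its flat param's name still
# contains "ln" and a name-based no-decay filter can catch it. MLP biases
# remain flattened into their block's flat param, hence still decayed.
wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={ResidualAttentionBlock, nn.LayerNorm},
)
```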

The other thing that needs to be changed is:

```python
exclude = lambda n, p: p.ndim < 2 or "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n
```

Since FSDP flattens everything, p.ndim is always < 2, so everything would be excluded in the current code, which means nothing would be weight-decayed.

I found out that, for ViTs at least, the only additional case that p.ndim < 2 covers is visual.class_embedding:

```python
self.class_embedding = nn.Parameter(scale * torch.randn(width))
```

The rest is covered by the other clauses. Or is this supposed to cover something else? @rwightman @rom1504 @mitchellnw @gabrielilharco
What about making the exclusions parametrizable, e.g. with regexps? Perhaps to be more explicit about which parameters to decay.

@mitchellnw (Contributor) commented:

The p.ndim < 2 check should also cover logit_scale.

@mehdidc (Contributor, Author) commented Feb 18, 2023

Yes, I was thinking of that as well, but saw that there is already 'logit_scale' in n in exclude.

@rwightman (Collaborator) commented:

@mehdidc The position, token, and class embeddings are typically not decayed either, but it looks like that was never done in OpenCLIP, hrmm. I have no_weight_decay methods in timm that return lists of names to exclude from decay, to cover the ndim >= 2 cases like position embeddings, etc.
https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L506
https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/maxxvit.py#L1229

I feel that name-based decay by itself is error-prone and won't generalize well to other models (like timm vision towers); the unflattened dim is the strongest signal (but sometimes you need to use names for some layers, like embeddings, etc.).

I feel the best approach would be to build a set of names to not decay (from both shape and name) before wrapping in FSDP, while that info is still available, i.e. move the current code a bit earlier. Something like the sketch below.
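A minimal sketch of that ordering, assuming use_orig_params wrapping as above; the decay value and the wrapper-prefix stripping are illustrative:

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# 1) Collect no-decay names from shape *and* name while the model is still
#    unwrapped, i.e. while unflattened shapes and original names exist.
def no_decay_names(model):
    return {
        n for n, p in model.named_parameters()
        if p.ndim < 2 or "bn" in n or "ln" in n
        or "bias" in n or "logit_scale" in n
    }

no_decay = no_decay_names(model)  # BEFORE FSDP(...)
model = FSDP(model, auto_wrap_policy=wrap_policy, use_orig_params=True)

# 2) Build the optimizer groups after wrapping. FSDP inserts a wrapper
#    prefix into parameter names, so strip it before matching the set.
def clean(name):
    return name.replace("_fsdp_wrapped_module.", "")

named = list(model.named_parameters())
optimizer = torch.optim.AdamW(
    [
        {"params": [p for n, p in named if clean(n) in no_decay],
         "weight_decay": 0.0},
        {"params": [p for n, p in named if clean(n) not in no_decay],
         "weight_decay": 0.2},  # illustrative value
    ],
    lr=1e-3,
)
```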

@rwightman (Collaborator) commented:

On the other topic, I feel it's fine to support nightlies only; I'm already exclusively using nightlies to train the convnext models because it's the only way to get decent bfloat16 support for convolutions. I'm going to add torch.compile soon to see how that works; nn.MHA is quite a bit faster on nightlies (it has a fused kernel).

@mehdidc (Contributor, Author) commented Feb 19, 2023

@rwightman Thanks for the suggestion, I moved the code a bit earlier; it is fixed now.

@mehdidc force-pushed the fsdp branch 3 times, most recently from 1347816 to b118c14 on February 19, 2023 10:18
@mehdidc (Contributor, Author) commented Feb 19, 2023

Update: @rwightman @rom1504 @mitchellnw @gabrielilharco @JeniaJitsev just for info: regarding the start-up phase I mentioned earlier (#358 (comment)), I found out that it is proportional not only to the number of nodes but also to model size, but I found a fix. Read below if you want more info.

For example, with 256 nodes on JUWELS Booster it took 13 minutes for ViT-B/32, 16 minutes for ViT-L/14, and 28 minutes for ViT-g/14; that is a lot of wasted time.

After checking the trace, I saw that it was not hanging, stuff was happening, and I found that FSDP does something special on the first forward pass (https://github.com/pytorch/pytorch/blob/85e0fd0280948a342a916429448fed2486e82aa5/torch/distributed/fsdp/_exec_order_utils.py#L210). After profiling, I found that there are two for loops (https://github.com/pytorch/pytorch/blob/85e0fd0280948a342a916429448fed2486e82aa5/torch/distributed/fsdp/_exec_order_utils.py#L235) which take about 12 seconds each (for ViT-B/32), and this is done for each FSDP unit (the number of FSDP units is proportional to model size, as we FSDP-wrap the residual blocks); multiply it all out and you get the explanation for the start-up times. The loops basically iterate over all pairs of ranks, about 1M iterations in total for the two loops, which shouldn't be slow by itself; the problem is a repeated element access to a tensor that lives on the GPU (https://github.com/pytorch/pytorch/blob/85e0fd0280948a342a916429448fed2486e82aa5/torch/distributed/fsdp/_exec_order_utils.py#L237), which slows things down. I will open an issue/PR; a simple .cpu() before the for loops solves the problem, and it is then a matter of seconds. A sketch of the pattern is below.
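For illustration, a minimal standalone sketch of the pattern behind the fix (names are illustrative, not the actual FSDP internals): indexing a CUDA tensor element by element inside a Python loop forces a device synchronization per access, while a single up-front .cpu() copy turns each access into a cheap host read.

```python
import torch

world_size = 1024
# Stand-in for the all-gathered bookkeeping tensor in FSDP's first-iteration
# exec-order check; in the real code it lives on the GPU.
world_indices = torch.zeros(world_size, world_size, device="cuda")

world_indices = world_indices.cpu()  # one device->host transfer
mismatches = 0
for rank1 in range(world_size):
    for rank2 in range(world_size):
        # With the tensor on CPU, each element read no longer triggers a
        # GPU synchronization.
        if world_indices[rank1, rank2].item() != 0:
            mismatches += 1
```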

@mehdidc changed the title from "Support FSDP" to "[WIP] Support FSDP" Feb 19, 2023
@orchidmajumder commented:

I have also observed the delay with FSDP on AWS clusters and actually thought FSDP hangs above a certain number of nodes, so I didn't pursue it further. Thanks for the amazing deep-dive @mehdidc.

@nkflash commented Feb 20, 2023

> @nkflash pushed, could you please try again? I can confirm that it worked for me

I checked out the head of the branch; it works well now.

@mehdidc (Contributor, Author) commented Mar 6, 2023

Update: now that the problem at large node counts is solved, here are updated scaling plots, up to 1024 GPUs:

[scaling plot: G-14]

I also tested freezing a subset of layers, with MT5-XXL as the text encoder (last 5 blocks trainable, rest frozen) and G-14 as the visual encoder (last block trainable, rest frozen), with patch dropout 0.5:

[scaling plot: mt5-xxl-ViT-G-14]

@mehdidc (Contributor, Author) commented Mar 12, 2023

Update: the first model fully trained with FSDP is finished. I started with a ViT-B/32 on LAION-400M, 32 epochs (96 GPUs, local batch size of 896, global batch size of 86016, lr of 0.001); zero-shot accuracy on ImageNet is 63.6%, with ~90K samples/s throughput. Training was done using pytorch-nightly (torch-2.0.0.dev20230218+cu117).

{"dataset": "wds/imagenet1k", "model": "ViT-B-32", "pretrained": "epoch_32.pt", "task": "zeroshot_classification", "metrics": {"acc1": 0.63606, "acc5": 0.87912, "mean_per_class_recall": 0.6360399999999999}, "language": "en"}


It's similar to what we get in https://arxiv.org/pdf/2212.07143.pdf (Table 13).

Commit messages (truncated by the page):

- …eter names to avoid erroneous parameter decay, and decay params by constructing a set of parameter names to decay before FSDP wrapping (thanks to @rwightman)
- …ict and we shard the optim state dict after loading
- …y from pytorch nightly
- fix grad checkpointing offloading to be compatible with pytorch nightly
- use sync_module_states
- Only import FSDP modules if possible to avoid import error