Do not wrap LoRA layers with FSDP #1538
base: main
Conversation
When wrapping the full Transformer `Block`, FSDP wraps both trainable and non-trainable parameters. Because of how FSDP is implemented, this results in much higher memory consumption, making the memory savings from LoRA meaningless. By instead wrapping only `torch.nn.Linear` modules, we still make use of FSDP but avoid wrapping the LoRA layers.
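As a rough illustration of the change, here is a minimal sketch of the two auto-wrap policies using Lightning Fabric's `FSDPStrategy` (which the litgpt finetuning scripts build on); the exact imports and strategy arguments in litgpt may differ, so treat the `Block` import and the surrounding setup as assumptions:

```python
import torch
from lightning.fabric import Fabric
from lightning.fabric.strategies import FSDPStrategy

from litgpt.lora import Block  # LoRA-aware Transformer block (import path assumed)

# Old behavior: shard whole Transformer blocks, so frozen base weights and
# trainable LoRA weights end up inside the same FSDP unit.
strategy_before = FSDPStrategy(auto_wrap_policy={Block})

# This PR: shard only the plain linear layers, keeping the small trainable
# LoRA parameters outside the FSDP-wrapped (and flattened) units.
strategy_after = FSDPStrategy(auto_wrap_policy={torch.nn.Linear})

fabric = Fabric(devices=4, strategy=strategy_after)
```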
Thanks for the update @janEbert ! This looks good to me. Btw have you done a comparison (re memory usage) before and after by chance? |
I have not, it was only OOM vs. no OOM. 😅 I can supply some comparisons if that would help! |
I see, yeah I think we should do some comparisons to make sure it works as intended. If you want to do them, that'd be nice! I suggest perhaps with a small model (phi-2 or so) and a medium-sized model (e.g., Llama 3 8B) at least |
Will do that in the coming days! |
That'd be awesome. And pls let me know in case you need any help! |
Forgot to print GPU ranks, but the point should be clear. :) Before the change:
After the change:
Is this helpful enough or would you like to see more detailed stats? |
cc @awaelchli |
@janEbert Looks awesome, which model is that? I am also rerunning some of the models in the config hub and will update the numbers accordingly! |
I just ran a quick comparison on a 4xA10G machine to see if I can reproduce the config hub performance.
For some reason, it's been really slow, but it could be a machine issue I have to look into. But here are some numbers I am getting for the code in this PR:
If I compare it to the performance before, I notice that the non-step steps are about 30% slower. E.g. from the main branch:
Just curious, have you observed something similar? (In this case we could maybe also think about an "optimize runtime|memory" setting here.) For comparison, the single-GPU speeds:
So I am thinking this could be due to a slow interconnect between the GPUs. I will look into it and do some more experiments. |
Why does the |
Not sure. I observed it with Phi-2 too: Main branch: `litgpt finetune_lora checkpoints/microsoft/phi-2/ --devices 4`
PR branch: `litgpt finetune_lora checkpoints/microsoft/phi-2/ --devices 4`
Something I need to investigate more in the next few days. I'll try this also on a different machine since I think the A10G machine has very slow GPU connections. |
I am curious if the whole Block was maybe accidentally trainable (instead of just the LoRA linear layers) before, which could explain the sharper loss decrease. But we should have tests for that, and I need to double-check that with a debugger. Just leaving this here as a note to myself so I can pick it up next week. |
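If it helps, a hypothetical sanity check along those lines (not part of litgpt; it assumes the trainable LoRA matrices carry "lora" in their parameter names, as the `lora_A`/`lora_B` matrices in litgpt do):

```python
import torch

def summarize_trainable(model: torch.nn.Module) -> None:
    """Print trainable vs. frozen parameter counts and flag unexpected trainables."""
    trainable = frozen = 0
    for name, param in model.named_parameters():
        if param.requires_grad:
            trainable += param.numel()
            if "lora" not in name.lower():
                print(f"unexpected trainable parameter: {name}")
        else:
            frozen += param.numel()
    print(f"trainable: {trainable:,} | frozen: {frozen:,}")
```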
Funny enough, in an entirely unrelated example, I've also noticed PyTorch Distributed becoming increasingly less reproducible for slightly changed settings with newer PyTorch versions. Could that maybe be the case here as well? Do you get a reproducible loss when running the same version of the code?
It's Mistral-7B-Instruct-v0.3 on a very small (4 samples) dummy JSON dataset, global batch size = 4, all other settings default. |
BTW my iteration speed is also slightly slower. I'll check if the version with |
That's a good point, but I think there is a different issue here that I am not understanding yet 😅. When I reran the code I observed basically the same higher loss. That's also independent of the model I tried. |
I ran into this as well, but looking at the code of the LoRA blocks, we will not be able to avoid mixed gradients within one block without rewriting the code there to not include frozen layers in the same block. If we wrap only frozen linear layers, it is not clear at all how the optimizer (still fully sharded) is updating. Maybe it doesn't properly, hence the loss not going down, or maybe it does, but who knows how the memory access works in that case, especially with multiple nodes. |
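To make the "mixed gradients within one block" point concrete: with the flat-parameter FSDP used here, every parameter inside one wrapped unit is flattened together, so a unit that mixes frozen and trainable parameters ends up paying for gradient handling on the frozen part as well. A hypothetical check for such units (the wrapped-unit class, e.g. litgpt's `Block`, is passed in by the caller):

```python
import torch

def find_mixed_requires_grad(model: torch.nn.Module, unit_cls: type) -> None:
    """Report candidate wrap units whose parameters disagree on requires_grad."""
    for name, module in model.named_modules():
        if isinstance(module, unit_cls):
            flags = {p.requires_grad for p in module.parameters()}
            if len(flags) > 1:
                print(f"{name}: mixes frozen and trainable parameters")

# Usage (names assumed): find_mixed_requires_grad(model, Block)
```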
Thanks for looking into this @TensorTemplar . I think that this may not be feasible then, so I am closing the PR for now. But happy to revisit this with other solutions in the future. |
If this is indeed what's happening, this would be a substantial oversight and would be really important to raise in PyTorch! Thanks a lot for giving this thought, though! I'm sorry I never managed to follow up. |
I haven't looked into the implementations (or even the explanations), but just a spontaneous thought: we should perhaps retry this with FSDPv2.
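For reference, a rough sketch of what the FSDPv2 route might look like. `fully_shard` shards parameters individually (per-parameter DTensor sharding) rather than flattening a whole block, so frozen and trainable parameters no longer share one flat buffer; the import path and defaults depend on the PyTorch version, so treat this as an assumption-laden sketch rather than a drop-in change:

```python
import torch
from torch.distributed.fsdp import fully_shard  # newer PyTorch; older versions expose it under torch.distributed._composable.fsdp

def shard_with_fsdp2(model: torch.nn.Module, block_cls: type) -> torch.nn.Module:
    """Apply per-parameter sharding to each Transformer block, then to the root model."""
    for module in model.modules():
        if isinstance(module, block_cls):
            fully_shard(module)
    fully_shard(model)
    return model
```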
Sorry, to clarify, I was referring to the original |