8-bit optimizers don't work with FSDP #89
I looked at the DeepSpeed implementation before, which had a similar issue with shared weights. The problem was that the algorithm splits all tensors found in the optimizer state, which includes the quantization statistics, and this can lead to incorrect behavior. The workaround in DeepSpeed is to hide the quantization statistics by obscuring their type (putting the tensor into a list/tuple). I am not sure if the error message that you provided is related to that or not. It would be nice if we could get 8-bit Adam working for FSDP. Would you be able to provide a simple example for debugging and replication purposes? Since I will be pretty busy over the next month, I would also be very happy to guide you on how to fix this if you create a PR and provide me with error messages / stack traces. I think it would be pretty useful, since more and more people are using FSDP.
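For illustration, a minimal sketch of that "hide the quantization statistics" trick, assuming bitsandbytes-style state keys (qmap/absmax); the helper names are hypothetical and not part of the bitsandbytes or DeepSpeed API:

import torch

def hide_quant_stats(optimizer):
    # Wrap the quantization statistics in a list so code that walks the
    # optimizer state and partitions every bare tensor skips them.
    for state in optimizer.state.values():
        for key in ("qmap1", "qmap2", "absmax1", "absmax2"):
            value = state.get(key)
            if torch.is_tensor(value):
                state[key] = [value]  # isinstance(state[key], torch.Tensor) is now False

def unhide_quant_stats(optimizer):
    # Restore the bare tensors before the next optimizer step.
    for state in optimizer.state.values():
        for key in ("qmap1", "qmap2", "absmax1", "absmax2"):
            value = state.get(key)
            if isinstance(value, list):
                state[key] = value[0]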
Hey @TimDettmers, I created a gist with an example: https://gist.github.com/philschmid/99410e8bf66d34e52bb0cd5270b07989 I hope that's enough for you to test it.
I tested the example I shared. With AdamWInt8:
{'loss': 2.6643, 'learning_rate': 4.847094801223242e-05, 'epoch': 0.09}
{'loss': 2.752, 'learning_rate': 4.694189602446483e-05, 'epoch': 0.18}
{'loss': 3.1493, 'learning_rate': 4.541284403669725e-05, 'epoch': 0.28}
{'loss': 3.412, 'learning_rate': 4.3883792048929664e-05, 'epoch': 0.37}
{'loss': 3.6722, 'learning_rate': 4.0825688073394495e-05, 'epoch': 0.55}
With Adafactor:
{'loss': 2.8385, 'learning_rate': 4.847094801223242e-05, 'epoch': 0.09}
{'loss': 2.6384, 'learning_rate': 4.694189602446483e-05, 'epoch': 0.18}
{'loss': 2.5725, 'learning_rate': 4.541284403669725e-05, 'epoch': 0.28}
{'loss': 2.5757, 'learning_rate': 4.3883792048929664e-05, 'epoch': 0.37}
{'loss': 2.5297, 'learning_rate': 4.0825688073394495e-05, 'epoch': 0.55}
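For reference, a hedged sketch of how the two optimizers above can be selected through the Hugging Face Trainer; the gist may wire this up differently, and the FSDP flags here are assumptions:

from transformers import TrainingArguments

args_8bit = TrainingArguments(
    output_dir="out-adamw-int8",
    fsdp="full_shard auto_wrap",   # assumption: FSDP enabled via Trainer flags
    optim="adamw_bnb_8bit",        # bitsandbytes 8-bit AdamW
    learning_rate=5e-5,
)

args_adafactor = TrainingArguments(
    output_dir="out-adafactor",
    fsdp="full_shard auto_wrap",
    optim="adafactor",
    learning_rate=5e-5,
)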
Hi @TimDettmers, in my latest test it turns out that saving the model is the source of this issue. Specifically, the error pops up when I run this:
optim_state = FSDP.full_optim_state_dict(model, optimizer)
What this is supposed to do is assemble the entire optimizer state based on the model params. What I think is the problem is that the optimizer is in 8-bit but the model is not in 8-bit. The reason for my assumption is that the error is thrown by:
File "/share03/draj/environments/.conda/envs/yanmtt/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 2136, in _all_gather_base
Indeed, if you look here: https://github.com/pytorch/pytorch/blob/55daa835e97a6e742cba1f0e9d2a5c78b1615e99/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp#L2779 there is a constraint that the dtypes of the tensors must be the same, and we are not able to guarantee this for a sharded 8-bit optimizer. If we can find some way to bypass this requirement, then we are good to go. How do we overcome this issue?
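A minimal sketch of that failing save path as described above (not a verified repro; assumes torchrun with at least 2 GPUs and a recent bitsandbytes):

import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import bitsandbytes as bnb

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

model = FSDP(torch.nn.Linear(4096, 4096).cuda())
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-4)

loss = model(torch.randn(8, 4096, device="cuda")).sum()
loss.backward()
optimizer.step()

# The optimizer state now mixes dtypes: the quantized moments are uint8 while
# their blockwise statistics (absmax) are float32. Gathering this state across
# ranks is where the error surfaces.
optim_state = FSDP.full_optim_state_dict(model, optimizer)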
I have the same issue. #323
There is another issue.
I'm not 100% sure, but this might be taken care of in PyTorch 2.0.
I encountered a similar issue using PEFT LoRA, load_in_8bit, and DeepSpeed ZeRO-3 (optimizer and parameter offload) with the Hugging Face Accelerate library. On a single GPU, training was fine as expected. If anyone has found a workaround to enable parallel training with PEFT LoRA and load_in_8bit, please let me know.
It seems that PyTorch 2 does not support 8-bit.
Is anyone still working on this? On the error @prajdabre mentioned, I find that the problem does not come from a dtype mismatch, but rather a size mismatch. With printf debugging, I noticed that this seemed to first error on the absmax1 value, with:
output_tensor.shape == Size([361496576]), output_tensor.dtype == float32
input_tensor.shape == Size([22064]), input_tensor.dtype == float32
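One hedged way to do this kind of printf debugging (an assumption about the approach, not necessarily what was done here) is to wrap the collective that replaced _all_gather_base in recent PyTorch and log shapes/dtypes before the failing call:

import torch.distributed as dist

_orig_all_gather = dist.all_gather_into_tensor  # public name that replaced _all_gather_base

def logged_all_gather(output_tensor, input_tensor, *args, **kwargs):
    # Print what is about to be gathered so the first mismatching state
    # tensor (e.g. absmax1) can be identified.
    print("output", tuple(output_tensor.shape), output_tensor.dtype,
          "input", tuple(input_tensor.shape), input_tensor.dtype)
    return _orig_all_gather(output_tensor, input_tensor, *args, **kwargs)

dist.all_gather_into_tensor = logged_all_gather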
cc @awgu
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Noting that this issue, although stale, remains an issue: although optimization can run, a functional state dict cannot be saved with 8-bit Adam. I notice that there is a PR for FSDP functionality in #840, but it generally does not address the state-dict issue in its tests.
@Titus-von-Koeller @TimDettmers sorry to hijack this issue; I'm doing something related but not exactly the same. I'm trying to use FSDP with
The FSDP wrapping will fail at
@152334H when you were trying this, did you load the model in 4/8-bit precision? Or is the model in 32-bit precision, but you want to activate
I do not test via Hugging Face. I was in fact trying to only use an 8-bit optimiser with 32-bit weights, though, so I do not experience the int8 FlatParameter issue you do.
Hey @152334H @fabianlim @HamidShojanazeri @prajdabre @Kyeongpil @hscspring @dotsnangles @philschmid, could some of you please retest this and let us know if the particular problems that you were observing persist in the same form, or, if they differ, please post detailed logs and a description? We just released official FSDP support in the latest BNB version. However, this release was not focused on 8-bit optimizer support yet. Be sure to install with
@Titus-von-Koeller @TimDettmers I think the problem still remains even with BNB 0.43. The reason is that BNB performs optimizer steps with CUDA; the update function first switches to the gradient's device:
prev_device = pre_call(g.device)
def pre_call(device):
    prev_device = torch.cuda.current_device()
    torch.cuda.set_device(device)
    return prev_device
is_on_gpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2])
and it checks that all of the above quantities are on the GPU. Thus, while one can move all of these quantities to GPU -> compute -> back to CPU, I'm not sure if this is the most optimal way to do things, as it will involve a lot of IO overhead.
@fabianlim Yes, you're right! Thanks for the detailed analysis, this really helps make things actionable. I'll put it on my list of things to look into, but can't promise a timeline. We have a lot on our plate in the immediate future, as there are a lot of necessary changes that need to be prioritized to make BNB more maintainable and easier to contribute to. In case you're interested in working with us on finding a solution, we would be super happy to collaborate and support you in any way!
@Titus-von-Koeller On one hand, we can work around this by loading all the quantities onto the GPU, but this will be very inefficient. On the other hand, I feel the better approach would be to run the optimizer step alongside the FSDP sharding. As we see here, the optimizer step can be run after the FSDP post-grad hook. There is a comment there saying that for CPU offload the parameters and gradients are run on CPU, but this does not have to be the case: if, during offload, we can run the optimizer step on the GPU before it gets offloaded, then this solves our problem and we do not need to shuffle params around. I have posted a comment on PyTorch asking when FSDP will start to support running
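For completeness, a rough sketch of the first (inefficient) option mentioned above, i.e. temporarily moving the offloaded parameters, gradients, and optimizer state to the GPU for the 8-bit step and back afterwards; the helper is hypothetical and pays the IO cost on every step:

import torch

def step_on_gpu(optimizer, device="cuda"):
    # Move params, grads, and optimizer state to the GPU so bitsandbytes'
    # CUDA kernels can run, then move everything back to the CPU.
    def move(to):
        for group in optimizer.param_groups:
            for p in group["params"]:
                p.data = p.data.to(to)
                if p.grad is not None:
                    p.grad.data = p.grad.data.to(to)
                state = optimizer.state.get(p, {})
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = v.to(to)
    move(device)
    optimizer.step()
    move("cpu")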
I have a repro using PL DDP. Here is a semi-minimal repro (the smallest I could get it), breaking on 2x A100. Repro gist: https://gist.github.com/isaacbmiller/fc871d732d4d6a6b7ede3190a6979f40
I don't have time to work on this in the next few weeks, due to needing to prioritize other work first. I see the big benefit of enabling this use-case and will prioritize it relatively high in the next months. I'll use this thread to keep you posted. Thanks for the minimal repro, really appreciated @isaacbmiller (those are always really useful!). ❤️
EDIT: ignore the below, it does not seem to work as expected after all. Interesting behavior: initially it seems to work, but after saving and reloading the checkpoint, I get an error about mismatching types. Just noting here that 8-bit AdamW seemed to work for me on FSDP with the following accelerate config:
Does anyone know of an alternative to bitsandbytes that we can use as a drop-in replacement until this gets fixed?
@musabgultekin did you find any replacement? bnb 8-bit optimizers don't work with FSDP1 or FSDP2.
No, unfortunately.
Use torchtune with torch 2.5 for lower memory requirements, if you're fine-tuning LLMs.
@musabgultekin did you explore the torchao low-bit optimizers? Do they work with FSDP2?
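For anyone exploring that route, a hedged sketch of what using a torchao low-bit optimizer looks like; the exact import path is an assumption and depends on the torchao version (older releases expose it under torchao.prototype.low_bit_optim, newer ones under torchao.optim), and FSDP2 compatibility should be verified against the torchao docs:

import torch
from torchao.optim import AdamW8bit  # assumed path; older versions: torchao.prototype.low_bit_optim

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = AdamW8bit(model.parameters(), lr=1e-4)

loss = model(torch.randn(8, 1024, device="cuda")).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()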
When I use an 8-bit ADAM with FSDP, I get an error as follows:
RuntimeError: output tensor must have the same type as input tensor
If my understanding is correct, there seems to be a casting issue. Is there any workaround for this?
TIA.