[BUG] Error while training with Deepspeed #4295
Comments
I ran into this exact same issue as well.
Any solution?
Some code for ZeRO3 assumes that all parameters in a model have the same dtype. This model has parameters in more than one dtype.
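As a quick way to check whether a model mixes dtypes (for example, quantized base weights next to fp16/fp32 LoRA adapters), one can enumerate parameter dtypes before handing the model to DeepSpeed. This is only a diagnostic sketch, not part of DeepSpeed itself; `model` is a placeholder:

```python
from collections import Counter

import torch

def parameter_dtypes(model: torch.nn.Module) -> Counter:
    """Count parameter elements per dtype. ZeRO-3's coalesced allgather
    (before the fix in #4647) effectively assumes this Counter has one key."""
    counts = Counter()
    for _, p in model.named_parameters():
        counts[p.dtype] += p.numel()
    return counts

# Usage (with a hypothetical `model`):
#   print(parameter_dtypes(model))
# More than one key (e.g. torch.float16 plus a quantized uint8 dtype from
# bitsandbytes) means the model mixes dtypes and can hit the error above.
```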
Did you fix this problem for now?
I have the same issue. I've attached my DeepSpeed config file. I'm running my training off the Axolotl library.
I submitted #4647 to address this issue. It is working in my environment.
This PR addresses an error reported in #4295. When parameters in multiple data types are given, DeepSpeed performs allgather for each data type.
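A rough sketch of the idea behind the fix (illustrative only, not DeepSpeed's actual implementation): group the parameters by dtype and run one collective per group, so each allgather only ever sees a single dtype.

```python
from collections import defaultdict

import torch
import torch.distributed as dist

def allgather_params_by_dtype(params):
    """Illustrative sketch: one allgather per dtype group.
    The real code path lives in Init._convert_to_deepspeed_param's
    all_gather_coalesced, as the error message shows."""
    groups = defaultdict(list)
    for p in params:
        groups[p.dtype].append(p)

    world_size = dist.get_world_size()
    gathered = {}
    for dtype, group in groups.items():
        flat = torch.cat([p.detach().reshape(-1) for p in group])
        out = [torch.empty_like(flat) for _ in range(world_size)]
        dist.all_gather(out, flat)  # uniform dtype within each call
        gathered[dtype] = out
    return gathered
```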
Hi @tohtana, I found an issue. I switched to your code and training worked, but model.save_pretrained(my_model) produced an adapter_model.bin of only 163 KB. I think the LoRA weights were not saved. How can I solve this problem?
Hi @momozzing, can you share the code to reproduce this?
OK. My baseline model is LLaMA. ZeRO stage 2 works well with this code; however, ZeRO stage 3 does not. Code:
ds_config_zero3
ds_config_zero2
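The attached config files did not survive this transcript. For orientation only, a representative ZeRO-3 configuration (assumed values, not the exact files attached above) can be written as a Python dict and passed to the HF integration or to deepspeed.initialize:

```python
# Representative ZeRO-3 config (assumed values, not the reporter's exact file).
ds_config_zero3 = {
    "bf16": {"enabled": "auto"},
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": True,
        "contiguous_gradients": True,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        # Relevant to the saving discussion below: gather full 16-bit
        # weights on save instead of writing Size([0]) placeholders.
        "stage3_gather_16bit_weights_on_model_save": True,
    },
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
}
# The equivalent ZeRO-2 config differs mainly in "stage": 2 and drops the
# stage3_* keys.
```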
Hi @momozzing. Also, the error you mentioned seems to be distinct from the initial problem. If it persists, I suggest creating a new issue to address it.
Hi @tohtana, I'm using this code, but save_checkpoint only saves the optimizer state; the model state is not saved (-rw-rw-r-- 1 519K 09:32 zero_pp_rank_0_mp_rank_00_model_states.pt). When I save the trained model, is there any way to save LoRA's trained weights?
Hi @momozzing, you can find an example of combining ZeRO-3 and LoRA in DeepSpeed-Chat. In the following example, it saves all the parameters, including the ones for LoRA.
Hi @tohtana, LLaMA + QLoRA without DeepSpeed saves adapter_model.bin at 477 MB. LLaMA + QLoRA with DeepSpeed ZeRO-2 saves adapter_model.bin at 477 MB, but LLaMA + QLoRA with DeepSpeed ZeRO-3 saves adapter_model.bin at 519 KB. There seems to be an issue where the LoRA parameters are saved with size torch.Size([0]). Is there any way to save LoRA's trained weights with DeepSpeed ZeRO-3? Does DeepSpeed ZeRO-3 support bitsandbytes?
Hi @momozzing, ZeRO-3 sets an empty size (Size([0])) on the parameter object and keeps the real tensor data in a different attribute. We cannot say that parameters are not saved just because we see Size([0]). Here is another example using the HF Trainer and LoRA. This script seems to save parameters properly. Can you check this as well?
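A minimal diagnostic sketch of that behaviour, using DeepSpeed's public deepspeed.zero.GatheredParameters context manager (`model` is a placeholder):

```python
import deepspeed
import torch.distributed as dist

def inspect_lora_params(model):
    """Show that ZeRO-3 parameters report an empty shape outside a gather
    but expose the full tensor inside one. Sketch only."""
    lora_params = [p for n, p in model.named_parameters() if "lora" in n.lower()]
    for p in lora_params[:3]:
        print("partitioned view:", tuple(p.shape))       # typically torch.Size([0])
    with deepspeed.zero.GatheredParameters(lora_params):
        if dist.get_rank() == 0:
            for p in lora_params[:3]:
                print("gathered view:  ", tuple(p.shape))  # full LoRA shape
```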
Hi @tohtana, here's how I solved it (see the sketch after this comment).
Thank you very much for your reply.
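The snippet from the comment above was not captured in this transcript. A plausible form of such a workaround, assuming it gathers every ZeRO-3 partition at once before saving (placeholder names, not necessarily the author's exact code):

```python
import deepspeed
import torch.distributed as dist

def save_adapter_zero3(model, output_dir):
    """Gather *all* partitioned parameters, then save on rank 0.
    Assumed reconstruction of the workaround; names are placeholders."""
    with deepspeed.zero.GatheredParameters(list(model.parameters()), modifier_rank=0):
        if dist.get_rank() == 0:
            model.save_pretrained(output_dir)  # adapter_model.bin now holds real weights
```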
This is a workaround, not a proper solution, as it can be really expensive.
The efficient way of doing it is to gather one layer at a time, which incurs little memory overhead; a sketch of that approach follows.
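A sketch of the layer-by-layer approach (placeholder names; only one module's parameters are materialized at a time):

```python
import deepspeed
import torch.distributed as dist

def gather_state_dict_layer_by_layer(model):
    """Collect a full state_dict on rank 0 by gathering one module at a time,
    keeping peak memory close to a single layer's size. Sketch only."""
    state_dict = {}
    for name, module in model.named_modules():
        params = list(module.parameters(recurse=False))
        if not params:
            continue
        with deepspeed.zero.GatheredParameters(params):
            if dist.get_rank() == 0:
                for pname, p in module.named_parameters(recurse=False):
                    key = f"{name}.{pname}" if name else pname
                    state_dict[key] = p.detach().cpu().clone()
    return state_dict  # empty dict on non-zero ranks
```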
Describe the bug
DeepSpeed runs into a bug while training a CodeLlama-34B model with QLoRA using this script.
To Reproduce
Run the script with the DeepSpeed config file passed into the params. The DeepSpeed config I used is given below:
Expected behavior
The expected behaviour is that DeepSpeed trains without any errors. Instead, the following error pops up, with the traceback given below:

RuntimeError: expected there to be only one unique element in <generator object Init._convert_to_deepspeed_param.<locals>.all_gather_coalesced.<locals>.<genexpr> at 0x7ff729d61cb0>

ds_report output
DeepSpeed C++/CUDA extension op report
NOTE: Ops not installed will be just-in-time (JIT) compiled at
runtime if needed. Op compatibility means that your system
meet the required dependencies to JIT install the op.
JIT compiled ops requires ninja
ninja .................. [OKAY]
op name ................ installed .. compatible
[WARNING] async_io requires the dev libaio .so object and headers but these were not found.
[WARNING] async_io: please install the libaio-dev package with apt
[WARNING] If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
[WARNING] sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.0
[WARNING] using untested triton version (2.0.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
DeepSpeed general environment info:
torch install path ............... ['/usr/local/lib/python3.10/dist-packages/torch']
torch version .................... 2.0.1+cu118
deepspeed install path ........... ['/usr/local/lib/python3.10/dist-packages/deepspeed']
deepspeed info ................... 0.10.3+542dc0d5, 542dc0d, master
torch cuda version ............... 11.8
torch hip version ................ None
nvcc version ..................... 11.8
deepspeed wheel compiled w. ...... torch 2.0, cuda 11.8
shared memory (/dev/shm) size .... 188.00 GB
Launcher context
Used the deepspeed launcher with the Hugging Face integration.