Fix error using deepspeed zero2 + load_in_8bit + lora #874

Merged: 2 commits merged into huggingface:main from lora-deepspeed-dtype-mismatch on Aug 31, 2023

Conversation

@tmm1 (Contributor) commented on Aug 28, 2023:

When running the HF Trainer with DeepSpeed ZeRO-2 to train with bf16 + load_in_8bit + LoRA/PEFT, the following error is observed:

```
RuntimeError: expected scalar type Float but found BFloat16
```

Poking around in a debugger, I saw:

```
-> breakpoint()
(Pdb) l
113         def forward(self, input: Tensor) -> Tensor:
114             try:
115                 return F.linear(input, self.weight, self.bias)
116             except Exception:
117                 import pdb
118  ->             breakpoint()
119
120         def extra_repr(self) -> str:
121             return 'in_features={}, out_features={}, bias={}'.format(
122                 self.in_features, self.out_features, self.bias is not None
123             )
(Pdb) input.dtype
torch.float32
(Pdb) self.weight.dtype
torch.bfloat16
(Pdb) up
> /home/tmm1/micromamba/envs/dev/lib/python3.10/site-packages/torch/nn/modules/module.py(1501)_call_impl()
-> return forward_call(*args, **kwargs)
(Pdb) up
> /home/tmm1/micromamba/envs/dev/lib/python3.10/site-packages/peft/tuners/lora.py(1167)forward()
-> output = lora_B(lora_A(dropout(x)))
(Pdb) l
1162                if requires_conversion:
1163                    expected_dtype = result.dtype
1164                    if x.dtype != torch.float32:
1165                        x = x.float()
1166
1167 ->             output = lora_B(lora_A(dropout(x)))
```
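So the 8-bit LoRA path unconditionally upcasts the activations to float32, while the LoRA adapter weights themselves are in bf16 under this setup, and `F.linear` refuses the mixed-dtype matmul. Below is a minimal sketch of the kind of change that resolves it, casting to the adapter's own compute dtype instead of hard-coding float32 (illustrative only, not the literal diff in this PR; `x`, `result`, `lora_A`, `lora_B`, `dropout` as in the listing above):

```python
# Illustrative sketch of the 8-bit LoRA forward path (not PEFT's exact code).
if requires_conversion:
    expected_dtype = result.dtype
    # Cast the activations to the dtype the adapter weights actually use
    # (bf16 here) instead of forcing float32.
    compute_dtype = lora_A.weight.dtype
    if x.dtype != compute_dtype:
        x = x.to(compute_dtype)

output = lora_B(lora_A(dropout(x)))

if requires_conversion:
    # Cast back so downstream code sees the dtype it expected.
    output = output.to(expected_dtype)
```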

Fixes axolotl-ai-cloud/axolotl#473

To reproduce:

```
git clone https://github.com/OpenAccess-AI-Collective/axolotl
cd axolotl
pip install -e .
accelerate launch scripts/finetune.py examples/llama-2/lora.yml --deepspeed=deepspeed/zero2.json
```
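For a lighter-weight illustration of the underlying failure, independent of axolotl/DeepSpeed (my own sketch, not part of this PR):

```python
import torch
import torch.nn.functional as F

# The LoRA adapter's Linear weight ends up in bf16, while the 8-bit code path
# upcasts the activations to float32 before the low-rank matmul.
weight = torch.randn(8, 16, dtype=torch.bfloat16)
x = torch.randn(4, 16, dtype=torch.float32)

try:
    F.linear(x, weight)
except RuntimeError as e:
    print(e)  # dtype-mismatch error along the lines of the one above

# Casting the input to the weight's dtype resolves it.
out = F.linear(x.to(weight.dtype), weight)
print(out.dtype)  # torch.bfloat16
```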

cc @TimDettmers

@BenjaminBossan (Member) commented:

Hey, thanks for investigating this issue.

After merging #807, there is now a merge conflict with your PR. Don't worry, it is easy to resolve. The code that you changed has now moved to this file: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/bnb.py. Just apply the same change there instead and it should be equivalent, thanks.

As for the change itself, after some grepping I saw that we have the same logic in AdaLoRA as well; I think it needs to be adjusted there too.
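Since the same upcast logic appears in both the LoRA and AdaLoRA 8-bit paths, the guard could in principle be expressed once; a hypothetical helper, for illustration only (not something PEFT actually exposes):

```python
import torch

def cast_to_adapter_dtype(x: torch.Tensor, adapter_weight: torch.Tensor) -> torch.Tensor:
    """Cast activations to the adapter's compute dtype before the low-rank matmul."""
    return x if x.dtype == adapter_weight.dtype else x.to(adapter_weight.dtype)
```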

Commit: "Fixes an issue when using bf16 + lora + load_in_8bit, observed with axolotl + deepspeed"

@tmm1 force-pushed the lora-deepspeed-dtype-mismatch branch from 98f13b7 to 6bde58e on August 29, 2023 12:08
@tmm1 (Contributor, Author) commented on Aug 29, 2023:

> there is now a merge conflict with your PR

Fixed.

> I saw that we have the same logic in AdaLoRA as well, I think it needs to be adjusted there too.

Done!

@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@BenjaminBossan (Member) left a comment:

Thanks a lot, this LGTM.

Let's have another review by @younesbelkada or @pacman100 before merging.

@younesbelkada (Contributor) left a comment:

Makes sense, thanks a lot @tmm1!

@BenjaminBossan merged commit f113af0 into huggingface:main on Aug 31, 2023.