
Fix flash-attn + qlora not working with llama models #336

Merged
4 commits merged into axolotl-ai-cloud:main on Aug 3, 2023

Conversation

@tmm1 (Collaborator) commented on Aug 3, 2023

fixes this error:

  File "/mnt/ml/axolotl/src/axolotl/flash_attn.py", line 98, in forward                                                    
    output_unpad = flash_attn_varlen_qkvpacked_func(                                                                       
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                       
  File "/mnt/ml/flash-attention/flash_attn/flash_attn_interface.py", line 406, in flash_attn_varlen_qkvpacked_func         
    return FlashAttnVarlenQKVPackedFunc.apply(                                                                             
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                             
  File "/home/tmm1/micromamba/envs/test/lib/python3.11/site-packages/torch/autograd/function.py", line 539, in apply       
    return super().apply(*args, **kwargs)  # type: ignore[misc]                                                            
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                  
  File "/mnt/ml/flash-attention/flash_attn/flash_attn_interface.py", line 123, in forward                                  
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(                                
                                                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^                                
  File "/mnt/ml/flash-attention/flash_attn/flash_attn_interface.py", line 52, in _flash_attn_varlen_forward                
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(                                
                                                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^                                
RuntimeError: FlashAttention only support fp16 and bf16 data type   

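The kernel's dtype check fails because, with a 4-bit QLoRA base model, the hidden states reaching the patched LLaMA attention can arrive as float32, so the packed qkv tensor handed to flash_attn_varlen_qkvpacked_func is neither fp16 nor bf16. Below is a minimal sketch of the kind of cast-around-the-kernel fix this implies; the helper name _flash_varlen_attention is hypothetical and the snippet assumes the patched forward already builds the packed qkv tensor plus cu_seqlens and max_seqlen, so it illustrates the idea rather than the exact diff merged here:

    import torch
    from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func

    def _flash_varlen_attention(qkv, cu_seqlens, max_seqlen, dropout_p=0.0, causal=True):
        # Hypothetical helper: with a 4-bit QLoRA base model the packed
        # (total_tokens, 3, num_heads, head_dim) qkv tensor can arrive as
        # float32, but the FlashAttention CUDA kernel only accepts fp16/bf16.
        # Cast down, run the kernel, then cast the output back.
        input_dtype = qkv.dtype
        compute_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        if input_dtype not in (torch.float16, torch.bfloat16):
            qkv = qkv.to(compute_dtype)
        out = flash_attn_varlen_qkvpacked_func(
            qkv, cu_seqlens, max_seqlen, dropout_p=dropout_p, causal=causal
        )
        return out.to(input_dtype)

Casting the output back to the original dtype keeps the rest of the forward pass unchanged, so only the kernel call itself runs in fp16/bf16.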
@tmm1 tmm1 merged commit 0d2e34f into axolotl-ai-cloud:main Aug 3, 2023
3 checks passed
mkeoliya pushed a commit to mkeoliya/axolotl referencing this pull request on Dec 15, 2023: Fix flash-attn + qlora not working with llama models