add flash attention #235
base: main
Conversation
This makes it seem like it's still using the original forward method instead of the one that's patched in? Wouldn't it say ...
Yeah, it seems like one method is getting patched but not the other. That causes the mismatch and the error, so the question is why the forward method isn't getting patched correctly.
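One way to narrow this down is to check which forward is actually bound on an attention layer. A minimal diagnostic sketch, assuming a LLaMA model from transformers; the checkpoint name is just an example:

```python
from transformers import AutoModelForCausalLM
from transformers.models.llama.modeling_llama import LlamaAttention

model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b")  # example checkpoint
attn = model.model.layers[0].self_attn

# True => an instance-level wrapper shadows any class-level patch on this module.
shadowed = "forward" in vars(attn)
# True => the method actually called is whatever LlamaAttention.forward currently is.
uses_class_fwd = getattr(attn.forward, "__func__", None) is LlamaAttention.forward

print(f"instance-level wrapper: {shadowed}, class forward in use: {uses_class_fwd}")
print(LlamaAttention.forward.__module__)  # shows whether the class-level patch itself took effect
```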
Here is the full stack trace:
Could it be possible that QLoRA patches it again after I patch it?
This seems suspicious with old_forward and new_forward. Maybe it's grabbing the original method before we patch it with flash-attn.
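For illustration, a minimal library-free sketch of how that can happen: if a wrapper captures old_forward and installs an instance-level new_forward before the class-level flash-attn patch runs, the instance keeps calling the original method and the patch is silently bypassed.

```python
class Attention:
    def forward(self):
        return "original forward"

attn = Attention()

# 1) A wrapper (e.g. a device-placement or quantization hook) runs first:
#    it captures the current forward and installs an instance-level replacement.
old_forward = attn.forward                            # bound to the ORIGINAL class method
attn.forward = lambda: f"wrapped({old_forward()})"    # new_forward, stored on the instance

# 2) The flash-attn style patch then replaces the method at the CLASS level.
def flash_forward(self):
    return "flash-attn forward"

Attention.forward = flash_forward

print(attn.forward())         # wrapped(original forward)  -> class patch is bypassed
print(Attention().forward())  # flash-attn forward         -> only fresh instances see the patch
```

Reversing the order (patch the class first, then let the wrapper capture forward) would make old_forward point at the flash-attn version, which is why the patch has to run before the model/QLoRA preparation.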
Ah, I think you need to call ...
Try this? 1b56419
OK, trying.
RuntimeError: FlashAttention only support fp16 and bf16 data type
Seems like maybe FlashAttention needs to be modified to support this.
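That error is flash-attn's own dtype check rather than something specific to this patch. A quick sketch, assuming flash-attn 2.x and a CUDA device, showing fp32 inputs being rejected while bf16 works:

```python
import torch
from flash_attn import flash_attn_func  # flash-attn 2.x API

# flash-attn expects (batch, seqlen, nheads, headdim) tensors in fp16/bf16 on CUDA.
q = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float32)
k, v = torch.randn_like(q), torch.randn_like(q)

try:
    flash_attn_func(q, k, v, causal=True)  # fp32 -> rejected
except (RuntimeError, AssertionError) as e:
    # The dtype check surfaces as RuntimeError or AssertionError depending on the version.
    print(e)  # "FlashAttention only support fp16 and bf16 data type"

out = flash_attn_func(q.bfloat16(), k.bfloat16(), v.bfloat16(), causal=True)
print(out.shape, out.dtype)  # torch.Size([1, 128, 8, 64]) torch.bfloat16
```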
Does this still work? I still get the same error: "RuntimeError: FlashAttention only support fp16 and bf16 data type".
Try this: #221 (comment). replace_attn_with_flash_attn() -> load model -> model = upcast_layer_for_flash_attention(model, torch.bfloat16) should fix this.
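Spelled out as a sketch of that ordering; the helper names are the ones referenced above and in #221, while the import path and checkpoint are placeholders:

```python
import torch
from transformers import AutoModelForCausalLM

# Hypothetical import path; in practice these helpers live in the repo's monkey-patch module.
from flash_attn_patch import replace_attn_with_flash_attn, upcast_layer_for_flash_attention

# 1) Patch the attention class BEFORE instantiating the model,
#    so every layer is constructed with the flash-attn forward.
replace_attn_with_flash_attn()

# 2) Load / prepare the model (QLoRA quantization etc. happens here).
model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b")  # example checkpoint

# 3) Upcast the layers that feed flash-attn to bf16, since flash-attn
#    only accepts fp16/bf16 inputs (the RuntimeError above).
model = upcast_layer_for_flash_attention(model, torch.bfloat16)
```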
I tried to add flash attention in the same way that fastchat and axolotl do. However, I get this error message. I wonder if you have any ideas?
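For reference, the general shape of the fastchat/axolotl-style patch this PR follows is a class-level swap of LlamaAttention.forward; the forward body below is only a stand-in for the flash-attn-based one in the actual patch:

```python
from transformers.models.llama.modeling_llama import LlamaAttention

_original_forward = LlamaAttention.forward

def flash_attn_forward(self, *args, **kwargs):
    # Stand-in body: the real patched forward computes q/k/v, applies rotary
    # embeddings, and calls flash-attn on fp16/bf16 tensors instead of delegating.
    return _original_forward(self, *args, **kwargs)

def replace_attn_with_flash_attn():
    # Must run before the model is created, so every layer binds the patched method.
    LlamaAttention.forward = flash_attn_forward
```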