flash_attention: support also cross attention. #8427
Conversation
Force-pushed from 840e3fe to 1bc1fb4.
In case q and kv have different shapes (cross attention), flash attention with SPMD fails, since the SPMD path does not support it.
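For context, a minimal sketch of the cross-attention case this PR enables, assuming the `flash_attention` wrapper in `torch_xla.experimental.custom_kernel` and its `partition_spec`/`mesh` keyword arguments; the shapes, mesh layout, and partition spec below are illustrative, not taken from the PR's tests.

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.custom_kernel import flash_attention

xr.use_spmd()
device = xm.xla_device()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1, 1, 1),
               ('data', 'model', 'seq', 'head'))

# Cross attention: q's sequence length differs from k/v's.
q = torch.randn(4, 2, 128, 64, device=device)   # [batch, heads, q_seq,  head_dim]
k = torch.randn(4, 2, 1024, 64, device=device)  # [batch, heads, kv_seq, head_dim]
v = torch.randn(4, 2, 1024, 64, device=device)

# Before this PR, the SPMD code path assumed q, k and v share a shape.
out = flash_attention(q, k, v, causal=False,
                      partition_spec=('data', None, None, None), mesh=mesh)
xm.mark_step()
```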
Force-pushed from 1bc1fb4 to 9e1e24b.
Rerun the tests when you can. @JackCaoG
@JackCaoG when using XLA_DISABLE_FUNCTIONALIZATION=1, the flash attention backward tests are failing with an error (unrelated to this PR specifically):
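For reference, a hedged sketch of the setup being described: functionalization is disabled via the environment variable (which has to be set before `torch_xla` is imported) and the flash-attention backward pass is exercised directly. The shapes and the direct `backward()` call are my own simplification rather than the actual test.

```python
import os
# Must be set before torch_xla is imported for it to take effect.
os.environ['XLA_DISABLE_FUNCTIONALIZATION'] = '1'

import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import flash_attention

device = xm.xla_device()
q = torch.randn(4, 2, 128, 64, device=device, requires_grad=True)
k = torch.randn(4, 2, 128, 64, device=device, requires_grad=True)
v = torch.randn(4, 2, 128, 64, device=device, requires_grad=True)

out = flash_attention(q, k, v, causal=True)
out.sum().backward()  # the failure reported above shows up in this backward pass
xm.mark_step()
```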
hmm, not the first time I have seen this issue; let me see if I can do anything before I return my laptop...
I can repro the issue with
so this has nothing to do with segment IDs. I checked XLA; it is from
there is something more fundamental; it seems like during the backward there is a graph break. I saw
before any mark_step, which is the cause of the crash above. I tried to revert this PR but I still see the same graph break. Need to first figure out where this is coming from.
ok I found an even better repro: if I just mark_step in the fwd it will crash too
with functionalization
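A sketch of that simplified repro as I read it (functionalization left enabled, a mark_step inserted between the flash_attention forward and the backward); shapes are illustrative:

```python
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import flash_attention

device = xm.xla_device()
q = torch.randn(4, 2, 128, 64, device=device, requires_grad=True)
k = torch.randn(4, 2, 128, 64, device=device, requires_grad=True)
v = torch.randn(4, 2, 128, 64, device=device, requires_grad=True)

out = flash_attention(q, k, v, causal=True)
xm.mark_step()        # graph break between forward and backward
out.sum().backward()  # the crash described above is triggered here
xm.mark_step()
```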
Ok, I narrowed it down to this line: xla/torch_xla/experimental/custom_kernel.py, line 319 at 591c397.
if I disable the functionalization and print
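The snippet being printed is elided above; for reference, one way to dump the pending lazy IR for an XLA tensor is the debug hook used in torch_xla's own tests (the tensor here is just a placeholder):

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.randn(2, 2, device=xm.xla_device()) + 1  # any tensor with pending lazy ops
print(torch_xla._XLAC._get_xla_tensors_text([t]))  # textual IR, including view/copy nodes
```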
ok I think the issue is in
the IR
the next step is to look into where the view logic kicks in. @miladm I will let you find someone to follow up.
@JackCaoG Hey! Can you merge this PR? Is it fine?
I am no longer with the team; you should check with @miladm.