Run the following code on a GPU with Python 3.10, CUDA 12.4, and torch 2.4.0:
```python
import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask

torch.manual_seed(0)

# Make q, k, v
B, M, H, K = 1, 10, 4, 8
Mk = 20
q = torch.rand((B, M, H, K)).cuda()
k = torch.rand((B, Mk, H, K)).cuda()
v = torch.rand((B, Mk, H, K)).cuda()

# Try to cut off K and V to a shorter length
for new_Mk in [1, 2, 10, 17, 20]:
    new_k = k[:, :new_Mk, :, :]
    new_v = v[:, :new_Mk, :, :]
    # NOTE: here we intentionally create the mask with length (M, Mk)
    # rather than (M, new_Mk) to trigger the issue.
    a = BlockDiagonalMask.from_seqlens([M], [Mk])
    result = xops.memory_efficient_attention(q, new_k, new_v, attn_bias=a)
    print(torch.sum(result).tolist())
```
Every iteration prints the same sum, which raises two questions:

1. While K and V are cut to a shorter length in each iteration, the mask still works: no error or warning is raised for the shape mismatch. Does broadcasting happen implicitly? This is confusing, because `torch.sdpa` would raise an error for a shape mismatch (see the sketch after this list).
2. Although the code runs without any error, the results are wrong, because we use a different length for K and V in each iteration.
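For comparison, here is a minimal sketch (my own illustration, not part of the original report) of feeding the same mismatched mask length to `torch.nn.functional.scaled_dot_product_attention`. A `(M, Mk)` mask cannot be broadcast against `(M, new_Mk)` scores, so SDPA raises instead of silently computing something:

```python
import torch
import torch.nn.functional as F

B, H, M, K = 1, 4, 10, 8
Mk, new_Mk = 20, 10

# Note: SDPA uses a (B, H, seq, dim) layout, unlike xformers' (B, seq, H, dim).
q = torch.rand(B, H, M, K, device="cuda")
k = torch.rand(B, H, new_Mk, K, device="cuda")  # K/V already cut to new_Mk
v = torch.rand(B, H, new_Mk, K, device="cuda")

# Boolean mask built for the *original* length Mk, i.e. shape (M, Mk).
mask = torch.ones(M, Mk, dtype=torch.bool, device="cuda")

# The scores have shape (B, H, M, new_Mk); a (M, Mk) mask is not
# broadcastable to that, so SDPA errors out instead of producing
# silently wrong results.
try:
    F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
except RuntimeError as e:
    print("SDPA rejected the mismatched mask:", e)
```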
I did some investigation without looking deeply into the source code; here is my guess:
1. Some broadcasting happens that allows the mask to be applied:
   - `attention_score`: `(B, H, M, new_Mk)`
   - `attention_mask`: `(B, H, M, Mk)`

   When `Mk != new_Mk`, the `attention_score` is broadcast to match the `attention_mask`.
2. The reason each iteration yields the same result is that the GPU memory of the original K and V is held across all iterations, so whenever the new K and V need to be broadcast, the values from the original K and V are reused to compute the attention scores. As a result, the new K and V are never really cut off.
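The memory-sharing part of this guess can be checked in isolation: slicing a tensor in PyTorch returns a view, so `new_k` and `new_v` still point into the storage that holds the full-length `k` and `v`, and their strides still describe the original `Mk`-long layout. A minimal sketch (my own illustration, not from the report):

```python
import torch

B, Mk, H, K = 1, 20, 4, 8
new_Mk = 10

k = torch.rand(B, Mk, H, K, device="cuda")
new_k = k[:, :new_Mk, :, :]

# The slice is a view: same underlying storage and the strides are
# unchanged, so a kernel that walks Mk rows with these strides would
# read the original values past the slice.
print(new_k.untyped_storage().data_ptr() == k.untyped_storage().data_ptr())  # True
print(new_k.stride() == k.stride())  # True
```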
To prove my hypothesis, I changed these two lines to prevent the new K and V from reusing that memory, and then got the expected results.
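The exact change is not shown above; a plausible version of it (my assumption, not necessarily the author's actual edit) is to copy the slices so each iteration gets a fresh allocation of exactly `new_Mk` rows:

```python
# Hypothetical version of the two sliced lines: .clone() forces a fresh
# allocation, so the kernel can no longer read the tail of the original
# k/v storage.
new_k = k[:, :new_Mk, :, :].clone()
new_v = v[:, :new_Mk, :, :].clone()
```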
So, can we add a check on the `attn_bias` argument that explicitly raises an error when its shape does not match the shape of the attention scores?
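Until such a check exists in the library, a user-side guard is straightforward. This is a hypothetical helper of my own (`make_checked_mask` is not an xformers API); it only uses `BlockDiagonalMask.from_seqlens` and the tensor shape:

```python
from xformers.ops.fmha.attn_bias import BlockDiagonalMask

def make_checked_mask(q_seqlens, kv_seqlens, key):
    """Hypothetical helper: build the mask and verify it covers exactly the
    key/value length that will be passed to memory_efficient_attention."""
    total_kv = sum(kv_seqlens)
    if total_kv != key.shape[1]:
        raise ValueError(
            f"attn_bias covers {total_kv} key positions, "
            f"but K/V have length {key.shape[1]}"
        )
    return BlockDiagonalMask.from_seqlens(q_seqlens, kv_seqlens)

# Usage inside the loop above:
# a = make_checked_mask([M], [new_Mk], new_k)
```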