Mamba2 torch_forward reduction dimension possibly incorrect? #34817
Comments
Yes, that seems correct. Good spotting! I have a fairly extended ramble on this below (ignore it if it's too much :) cc @molbap We can also see it from the einsum notation in the ssd minimal script: In this case, as we do not use einsum notation, I will write the same dimension names before the sum and mark broadcast dimensions (added via None) simply as 1: Just a quick idea: I'm not sure we even have to reshape twice; reshaping once, on the decay chunks only, might be enough (not checked):
I'm a bit surprised that the operations after that don't fail. Have you tested your fixed version on a forward pass?
As far as I remember, the following operations won't fail because the reduction was over the number of (source) chunks even though it should be over the number of (target) chunks. During training, these two are the same size, so the output shape is identical either way.
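To make the shape argument concrete, here is a minimal sketch. The tensor sizes and the helper name reduce_both are hypothetical, random values stand in for the real decay and state tensors, and the layouts follow the einsum letters from the ssd minimal script (z and c are the two chunk axes of the decay matrix). When both chunk axes have the same length, summing over either one yields the same shape, so nothing downstream complains.

```python
import torch

# Hypothetical small sizes: batch, heads, head dim, state dim.
b, h, p, n = 2, 3, 4, 5

def reduce_both(z, c):
    """Return the output shapes for a sum over dim=2 and over dim=3."""
    decay_chunk = torch.rand(b, h, z, c)   # (b, h, z, c)
    states = torch.rand(b, h, c, p, n)     # assumed already permuted to (b, h, c, p, n)
    # Broadcast to (b, h, z, c, p, n) before reducing.
    prod = decay_chunk[..., None, None] * states[:, :, None, ...]
    return prod.sum(dim=2).shape, prod.sum(dim=3).shape

same_a, same_b = reduce_both(4, 4)   # equal chunk counts: shapes coincide
diff_a, diff_b = reduce_both(4, 6)   # unequal chunk counts: shapes differ
print(same_a, same_b)
print(diff_a, diff_b)
```

With equal chunk counts the two shapes are indistinguishable, which is why a wrong reduction axis can only be caught by comparing values, not shapes.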
It's been a while but yea that makes sense. Thx for clarifying!
A tad late, but I've verified it myself now based on my test and modifying the respective local ssd minimal:
decay_chunk = decay_chunk.transpose(1, 3)
new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1)
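For anyone who wants to repeat that check without the full model, the following is a self-contained sketch. It assumes decay_chunk has layout (b, h, z, c) and states has layout (b, c, h, p, n), as in the einsum string "bhzc,bchpn->bzhpn" from the ssd minimal script; the sizes are made up and random tensors replace the real decay and state values.

```python
import torch

torch.manual_seed(0)
# Hypothetical sizes: batch, heads, chunks, head dim, state dim.
b, h, c, p, n = 2, 3, 4, 5, 6

decay_chunk = torch.rand(b, h, c, c, dtype=torch.float64)   # (b, h, z, c)
states = torch.rand(b, c, h, p, n, dtype=torch.float64)     # (b, c, h, p, n)

# Reference: einsum as written in the ssd minimal script.
reference = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)

# Transpose the decay chunks once so that c sits at dim=1, next to batch.
decay_t = decay_chunk.transpose(1, 3)                        # (b, c, z, h)
# Broadcast to (b, c, z, h, p, n), then reduce over c.
new_states = (decay_t[..., None, None] * states[:, :, None, ...]).sum(dim=1)

matches = torch.allclose(new_states, reference)
print(new_states.shape, matches)
```

Because the result already comes out as (b, z, h, p, n), this variant needs no permute of states before or after the reduction.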
System Info
NA
Who can help?
@ArthurZucker
Reproduction
NA
Expected behavior
In the torch_forward part of Mamba2, it seems like the reduction dimension should be dim=3 instead of dim=2?
transformers/src/transformers/models/mamba2/modeling_mamba2.py
Line 560 in 3033509
With dim=3, the output seems to more or less match that of Mamba-2's ssd_minimal implementation, but not with dim=2.
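A rough way to reproduce the comparison in isolation is sketched below. It is not the code at the referenced line: the permute calls around the reduction are an assumption about how the surrounding code arranges the tensors, the sizes are arbitrary, and random tensors stand in for the actual decay and state values. The reference is the einsum "bhzc,bchpn->bzhpn" from ssd_minimal.

```python
import torch

torch.manual_seed(0)
# Hypothetical sizes: batch, heads, chunks, head dim, state dim.
b, h, c, p, n = 2, 3, 4, 5, 6

decay_chunk = torch.rand(b, h, c, c, dtype=torch.float64)   # (b, h, z, c)
states = torch.rand(b, c, h, p, n, dtype=torch.float64)     # (b, c, h, p, n)
reference = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)

# Assumed layout for the broadcast version: move heads in front of chunks.
states_permuted = states.permute(0, 2, 1, 3, 4)             # (b, h, c, p, n)
prod = decay_chunk[..., None, None] * states_permuted[:, :, None, ...]  # (b, h, z, c, p, n)

# Reduce over dim=3 (c) versus dim=2 (z), then restore (b, chunk, h, p, n).
out_dim3 = prod.sum(dim=3).permute(0, 2, 1, 3, 4)
out_dim2 = prod.sum(dim=2).permute(0, 2, 1, 3, 4)

dim3_ok = torch.allclose(out_dim3, reference)
dim2_ok = torch.allclose(out_dim2, reference)
print(dim3_ok, dim2_ok)
```

Both outputs have the same shape here, so only the value comparison against the einsum result separates the two reductions.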