Mamba2 torch_forward reduction dimension possibly incorrect? #34817

HanGuo97 opened this issue Nov 19, 2024 · 5 comments · May be fixed by #34901

@HanGuo97

System Info

NA

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

NA

Expected behavior

In the torch_forward path of Mamba2, it seems like the reduction dimension should be dim=3 instead of dim=2:

result = (decay_chunk[..., None, None] * states_permuted[:, :, None, ...]).sum(dim=2)

With dim=3, the output more or less matches that of Mamba-2's ssd_minimal implementation; with dim=2 it does not.
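A minimal, self-contained sketch of the proposed fix (the sizes and shape layout here are assumptions for illustration, mirroring the variable names above):

```python
import torch

# Illustrative sizes; z/c are the numbers of target/source chunks.
b, h, z, c, p, n = 2, 4, 3, 3, 8, 16
decay_chunk = torch.randn(b, h, z, c)         # (batch, heads, target_chunks, source_chunks)
states_permuted = torch.randn(b, h, c, p, n)  # (batch, heads, source_chunks, head_dim, state_dim)

# The broadcasted product has shape (b, h, z, c, p, n); the contraction should
# run over the source-chunk axis c, which sits at dim=3 -- not over z at dim=2.
result = (decay_chunk[..., None, None] * states_permuted[:, :, None, ...]).sum(dim=3)
```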

HanGuo97 added the bug label Nov 19, 2024
@vasqu
Contributor

vasqu commented Nov 19, 2024

Yes, that seems correct. Good spot! I have a fairly extended ramble on this below (ignore it if it's too much :) cc @molbap

We can also see it from the einsum notation in the ssd_minimal script:
bhzc,bchpn->bzhpn

Since we don't use einsum notation here, I'll write out the dimension labels before the sum, marking broadcast (via None) dimensions as a simple 1:
decay chunk: bhzc11
permuted states: bh1cpn
The multiplication before the sum therefore gives bhzcpn, and since we want shape bzhpn we need to sum along c (on dim=3) and permute afterwards.
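A minimal numerical check of this reasoning (shapes are assumptions; z != c is chosen deliberately so a wrong reduction axis cannot hide):

```python
import torch

b, h, z, c, p, n = 2, 4, 3, 5, 8, 16
decay_chunk = torch.randn(b, h, z, c)
states = torch.randn(b, c, h, p, n)

# Reference contraction from ssd_minimal.
ref = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)

# Broadcast-multiply to bhzcpn, sum over c (dim=3), then permute back to bzhpn.
states_permuted = states.permute(0, 2, 1, 3, 4)  # (b, h, c, p, n)
out = (decay_chunk[..., None, None] * states_permuted[:, :, None, ...]).sum(dim=3)
out = out.permute(0, 2, 1, 3, 4)                 # (b, z, h, p, n)

torch.testing.assert_close(ref, out)
```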

Just a quick idea: I'm not sure we even have to permute twice; we could instead permute only the decay chunks (not checked):
states: bc1hpn
permuted decay chunks: bczh11
This results in bczhpn and finally bzhpn (after summing on dim=1), so we avoid the double permutation and only do it once.

@vasqu
Contributor

vasqu commented Nov 19, 2024

I'm a bit surprised that the operations following that one don't fail. Have you tested your fixed version on a forward pass?

@HanGuo97
Author

As far as I remember, the following operations won't fail because the reduction was over the number of (source) chunks even though it should be over the number of (target) chunks. During training, these two dimensions have the same size.
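To illustrate why the shapes alone don't catch this (sizes assumed): when z == c, summing the broadcasted product over dim=2 or dim=3 yields the same output shape, just different values:

```python
import torch

b, h, z, c, p, n = 2, 4, 3, 3, 8, 16  # z == c, as during training
prod = torch.randn(b, h, z, c, p, n)  # stand-in for the broadcasted product

wrong = prod.sum(dim=2)  # reduces over target chunks z
right = prod.sum(dim=3)  # reduces over source chunks c
print(wrong.shape, right.shape)      # both torch.Size([2, 4, 3, 8, 16])
print(torch.allclose(wrong, right))  # False: same shape, different values
```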

@vasqu
Contributor

vasqu commented Nov 20, 2024

It's been a while, but yeah, that makes sense. Thanks for clarifying!

@vasqu
Contributor

vasqu commented Nov 20, 2024

A tad late, but I've now verified it myself, based on my test and by modifying the respective local ssd_minimal:

  • Either we need to sum on dim=3
  • Or use transposed decays, as I suggested above (one permutation less):
decay_chunk = decay_chunk.transpose(1, 3)
new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1)
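
For completeness, a quick sketch (shapes assumed) checking that both fixes agree:

```python
import torch

b, h, z, c, p, n = 2, 4, 3, 5, 8, 16
decay_chunk = torch.randn(b, h, z, c)
states = torch.randn(b, c, h, p, n)

# Fix 1: keep the original permutation and sum over dim=3 (the c axis).
states_permuted = states.permute(0, 2, 1, 3, 4)
fix1 = (decay_chunk[..., None, None] * states_permuted[:, :, None, ...]).sum(dim=3)
fix1 = fix1.permute(0, 2, 1, 3, 4)  # (b, z, h, p, n)

# Fix 2: transpose the decays instead and sum over dim=1 (again the c axis).
decay_t = decay_chunk.transpose(1, 3)  # (b, c, z, h)
fix2 = (decay_t[..., None, None] * states[:, :, None, ...]).sum(dim=1)

torch.testing.assert_close(fix1, fix2)
```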
