Fix Mamba2 Grouped SSD Support in the torch_forward Path #37533
Conversation
Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the "Ready for review" button.
That's a good catch! Can confirm the issue and solution (based on some local tests). For reference, the original mamba repo also points to this in https://github.com/state-spaces/mamba/blob/2e16fc3062cdcd4ebef27a9aa4442676e1c7edf4/mamba_ssm/ops/triton/ssd_chunk_scan.py#L1813-L1814 (when looking at the repeat pattern). This also affects other mamba2 based models: could you patch those as well, since they basically copy from mamba2? 😄 cc @molbap
Sure, I can patch them as well.
I don't have much experience writing tests for transformers. Would this mean adding a unit test that uses ngroups > 1 in the model configuration to each of the affected models?
I think something along the lines of transformers/tests/models/mamba2/test_modeling_mamba2.py, lines 237 to 239 at 4cc6b60, in mamba2 (by adjusting groups/heads for that test) would be sufficient, as both bamba and zamba2 basically copied from mamba2. The models should be refactored tbh to allow modular to copy (but that's not in scope for this PR).
Yes, very likely why that was missed haha. Nice catch indeed, following here and yes we should modularize it all, will open an issue!
Force-pushed from 2228188 to 29b5b89
Force-pushed from 7114e91 to 7dacdec
@vasqu I added a test that uses half the default n_groups. I tested this locally on mamba2 and confirmed that the test fails before the patch and passes after it.
Just a small nit. Otherwise LGTM!
I guess slow runs just to be sure nothing is majorly broken? @molbap
def test_mamba2_slow_vs_fast_forward_grouped(self):
    config_and_inputs = self.model_tester.prepare_config_and_inputs()
    config_and_inputs[0].n_groups //= 2
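For orientation, here is a rough sketch of how this grouped test might look in full; it is not the literal diff. The helper create_and_check_mamba2_slow_vs_fast_forward is assumed based on the pattern of the existing ungrouped slow-vs-fast test, and the real body may differ.

def test_mamba2_slow_vs_fast_forward_grouped(self):
    # See https://github.com/huggingface/transformers/pull/37533
    # Halve n_groups so that n_groups > 1 and n_groups != num_heads, which is
    # the configuration where the torch_forward head expansion used to break.
    config_and_inputs = self.model_tester.prepare_config_and_inputs()
    config_and_inputs[0].n_groups //= 2
    # Assumed helper mirroring the ungrouped test: it runs the pure-PyTorch
    # (slow) path and the kernel (fast) path and compares their outputs.
    self.model_tester.create_and_check_mamba2_slow_vs_fast_forward(*config_and_inputs)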
Just a small nit: Could you add a comment / link to this PR so we know in the future why this test was added.
comment added
it's nice but I think he meant literally adding
# See https://github.com/huggingface/transformers/pull/37533
We do that a lot across the library to keep the history :)
ah.. ok let me do that
Done!
run-slow: mamba2
This comment contains run-slow, running the specified jobs: models: ['models/mamba2']
Hmm, something went wrong with picking up the commit? Not familiar with the new workflow.
Force-pushed from 7dacdec to 7d874fb
@molbap can we try slow runs again?
Hey, I'm not sure about the new workflow either haha :D I mostly ran it locally to be sure, and it seems to not break, but it would be preferable to check on our runners. Retrying, cc @ydshieh: is this the correct current launch? I thought it was (no need for labels, just a ready PR + the message from a maintainer)? run-slow mamba2

run-slow mamba2
yes, but I think it should not be mixed with other comment text; simply post run-slow mamba2 on its own.
This comment contains run-slow, running the specified jobs: models: ['models/mamba2']
all tests passing on the slow CI, congrats 🙌
Force-pushed from c794c1b to c029cf0
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
cc @Cyrilvallez for core maintainer review
Merging! Thanks a lot for the fix and for the clean PR @cyang49! 🤗
…#37533)
* Fix mamba2 grouped support in bamba torch path
* patch zamba2 and mamba2
* Add a unit test for grouped SSD
* add comment for the new unit test
* add output_size arg value to repeat_interleave calls
* Add comment
What does this PR do?
We found a bug in Bamba's torch_forward implementation where the Mamba2 grouped SSD heads are incorrectly expanded for the computations. The original code uses torch.repeat, which produces a tile-like pattern: e.g. when ngroups=4 and num_heads=16, torch.repeat tiles the groups across the head dimension instead of giving each group its own contiguous block of heads. This causes models using ngroups > 1 and ngroups != num_heads to fail evaluations. We solve it by replacing torch.repeat with torch.repeat_interleave. The bug was left undetected for a while, perhaps because the cuda_forward path is used by most people, or because Bamba-9B uses ngroups=1.
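To make the repeat vs. repeat_interleave difference concrete, here is a small standalone illustration; it is not the actual modeling code, and the variable names are chosen only for the example. Each group should serve one contiguous block of heads, which is exactly what repeat_interleave produces.

import torch

n_groups, num_heads = 4, 16
heads_per_group = num_heads // n_groups  # 4

# Group indices stand in for the grouped SSD parameters (e.g. the B/C states),
# with one entry per group before expansion to the per-head layout.
group_ids = torch.arange(n_groups)  # tensor([0, 1, 2, 3])

# torch.repeat tiles the whole group axis:
# tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3])
tiled = group_ids.repeat(heads_per_group)

# torch.repeat_interleave keeps each block of heads attached to its own group,
# which is the intended head-to-group mapping:
# tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3])
grouped = group_ids.repeat_interleave(heads_per_group, output_size=num_heads)

The two expansions only coincide when ngroups == 1 or ngroups == num_heads, which is consistent with the bug going unnoticed for configurations like Bamba-9B. The output_size keyword (also mentioned in the commit list) only pre-specifies the length of the result so PyTorch can skip computing it; it does not change the values.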
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@fabianlim @ani300 @ArthurZucker