
Conversation

@cyang49
Contributor

@cyang49 cyang49 commented Apr 15, 2025

What does this PR do?

We found a bug in the Bamba torch_forward implementation where the Mamba2 grouped SSD heads are expanded incorrectly for the computation.

The original code uses torch.repeat, which produces a tiled pattern. For example, with ngroups=4 and num_heads=16, torch.repeat gives

[W, X, Y, Z] --> [W, X, Y, Z, W, X, Y, Z, W, X, Y, Z, W, X, Y, Z]

instead of the desired

[W, X, Y, Z] --> [W, W, W, W, X, X, X, X, Y, Y, Y, Y, Z, Z, Z, Z]

This causes models with ngroups > 1 and ngroups != num_heads to fail evaluations. We fix it by replacing torch.repeat with torch.repeat_interleave.
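
For illustration, a minimal standalone sketch of the two expansion patterns (not the model code itself; the integer values simply stand in for the per-group SSD states W, X, Y, Z):

import torch

# Toy repro of the expansion difference, assuming ngroups=4 and num_heads=16
ngroups, num_heads = 4, 16
groups = torch.arange(ngroups)  # stands in for the per-group states [W, X, Y, Z]

tiled = groups.repeat(num_heads // ngroups)                   # torch.repeat: tiled pattern
interleaved = groups.repeat_interleave(num_heads // ngroups)  # torch.repeat_interleave: desired pattern

print(tiled.tolist())        # [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
print(interleaved.tolist())  # [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]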

The bug went undetected for a while, perhaps because most people use the cuda_forward path, or because Bamba-9B uses ngroups=1.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@fabianlim @ani300 @ArthurZucker

@github-actions github-actions bot marked this pull request as draft April 15, 2025 14:56
@github-actions
Contributor

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button (at the bottom of the PR page). This will assign reviewers and trigger CI.

@cyang49 cyang49 marked this pull request as ready for review April 15, 2025 14:57
@github-actions github-actions bot requested a review from ArthurZucker April 15, 2025 14:57
@vasqu
Contributor

vasqu commented Apr 15, 2025

That's a good catch! Can confirm the issue and solution (based on some local tests). For reference, the original mamba repo also points to this in https://github.com/state-spaces/mamba/blob/2e16fc3062cdcd4ebef27a9aa4442676e1c7edf4/mamba_ssm/ops/triton/ssd_chunk_scan.py#L1813-L1814 (when looking at the repeat pattern).

This also affects other mamba2-based models. Could you:

  1. Fix the related models (mamba2, zamba2) as well
  2. Add a test

@vasqu
Contributor

vasqu commented Apr 15, 2025

mamba2 😄 cc @molbap

@cyang49
Contributor Author

cyang49 commented Apr 15, 2025

That's a good catch! Can confirm the issue and solution (based on some local tests). For reference, the original mamba repo also points to this in https://github.com/state-spaces/mamba/blob/2e16fc3062cdcd4ebef27a9aa4442676e1c7edf4/mamba_ssm/ops/triton/ssd_chunk_scan.py#L1813-L1814 (when looking at the repeat pattern).

This also affects other mamba2-based models. Could you:

  1. Fix the related models (mamba2, zamba2) as well

Sure, I can patch them as well.

  2. Add a test

I don't have much experience writing tests for transformers. Would this mean adding a unit test with ngroups > 1 in the model configuration for each of the affected models?

@vasqu
Contributor

vasqu commented Apr 15, 2025

I think something along

def test_mamba2_slow_vs_fast_forward(self):
    config_and_inputs = self.model_tester.prepare_config_and_inputs()
    self.model_tester.create_and_check_mamba2_slow_vs_fast_forward(*config_and_inputs)

in mamba2 (by adjusting groups/heads for that test) would be sufficient, as both bamba and zamba2 basically copied from mamba2. The models should be refactored tbh to allow modular to copy (but that's not in scope for this PR).
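
For illustration, a grouped variant along these lines might look like the following sketch (the _grouped test name and halving n_groups are just one possible way to wire up the adjustment):

def test_mamba2_slow_vs_fast_forward_grouped(self):
    config_and_inputs = self.model_tester.prepare_config_and_inputs()
    # Use fewer groups than heads so the grouped-head expansion path in torch_forward is exercised
    config_and_inputs[0].n_groups //= 2
    self.model_tester.create_and_check_mamba2_slow_vs_fast_forward(*config_and_inputs)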

@molbap
Contributor

molbap commented Apr 15, 2025

Bamba-9B uses ngroups=1.

Yes, very likely why that was missed haha. Nice catch indeed, following here and yes we should modularize it all, will open an issue!

@cyang49 cyang49 marked this pull request as draft April 15, 2025 19:10
@cyang49 cyang49 changed the title Fix Mamba2 Grouped SSD Support in Bamba torch_forward Path Fix Mamba2 Grouped SSD Support in the torch_forward Path Apr 15, 2025
@cyang49
Contributor Author

cyang49 commented Apr 15, 2025

I think something along

def test_mamba2_slow_vs_fast_forward(self):
    config_and_inputs = self.model_tester.prepare_config_and_inputs()
    self.model_tester.create_and_check_mamba2_slow_vs_fast_forward(*config_and_inputs)

in mamba2 (by adjusting groups/heads for that test) would be sufficient, as both bamba and zamba2 basically copied from mamba2. The models should be refactored tbh to allow modular to copy (but that's not in scope for this PR).

@vasqu I added a test and used half the default n_groups. I tested this locally on mamba2 and confirmed that the test fails before the patch and passes after it.

@cyang49 cyang49 marked this pull request as ready for review April 15, 2025 20:39
Contributor

@vasqu vasqu left a comment

Just a small nit. Otherwise LGTM!

I guess slow runs just to be sure nothing is majorly broken? @molbap


def test_mamba2_slow_vs_fast_forward_grouped(self):
    config_and_inputs = self.model_tester.prepare_config_and_inputs()
    config_and_inputs[0].n_groups //= 2
Contributor

Just a small nit: Could you add a comment / link to this PR so we know in the future why this test was added.

Contributor Author

comment added

Contributor

it's nice but I think he meant adding literally

# See https://github.com/huggingface/transformers/pull/37533/

we do that a lot across the library to keep the history :)

Contributor Author

ah.. ok let me do that

Contributor Author

Done!

@molbap
Contributor

molbap commented Apr 16, 2025

run-slow: mamba2

@github-actions
Contributor

This comment contains run-slow, running the specified jobs:

models: ['models/mamba2']
quantizations: [] ...

@vasqu
Contributor

vasqu commented Apr 16, 2025

Hmm, something went wrong with picking up the commit? Not familiar with the new workflow.

@vasqu
Contributor

vasqu commented Apr 16, 2025

@molbap can we try slow runs again?

@molbap
Contributor

molbap commented Apr 16, 2025

Hey, I'm not sure about the new workflow either haha :D I mostly ran it locally to be sure, and it seems not to break, but it would be preferable to check on our runners. Retrying, cc @ydshieh: is this the correct current way to launch it? I thought it was (no need for labels, just a ready PR + a message from a maintainer)?

run-slow mamba2

@molbap
Contributor

molbap commented Apr 16, 2025

run-slow mamba2

@ydshieh
Collaborator

ydshieh commented Apr 16, 2025

Yes, but I think it should not be mixed with other comment text. Simply run-slow: mamba2, with or without the :.

@github-actions
Contributor

This comment contains run-slow, running the specified jobs:

models: ['models/mamba2']
quantizations: [] ...

@molbap
Contributor

molbap commented Apr 16, 2025

all tests passing on the slow CI, congrats 🙌

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@vasqu
Contributor

vasqu commented Apr 16, 2025

cc @Cyrilvallez for core maintainer review

Member

@Cyrilvallez Cyrilvallez left a comment

Merging! Thanks a lot for the fix and for the clean PR @cyang49! 🤗

@Cyrilvallez Cyrilvallez merged commit 4005730 into huggingface:main Apr 16, 2025
12 checks passed
@cyang49 cyang49 deleted the pr_mamba2_groups branch April 16, 2025 20:22
cyr0930 pushed a commit to cyr0930/transformers that referenced this pull request Apr 18, 2025
…#37533)

* Fix mamba2 grouped support in bamba torch path

* patch zamba2 and mamba2

* Add a unit test for grouped SSD

* add comment for the new unit test

* add output_size arg value to repeat_interleave calls

* Add comment
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request May 14, 2025
…#37533)

* Fix mamba2 grouped support in bamba torch path

* patch zamba2 and mamba2

* Add a unit test for grouped SSD

* add comment for the new unit test

* add output_size arg value to repeat_interleave calls

* Add comment
@vasqu vasqu mentioned this pull request May 22, 2025
