Commit b813836

cyang49 authored and zucchini-nlp committed
Fix Mamba2 Grouped SSD Support in the torch_forward Path (huggingface#37533)
* Fix mamba2 grouped support in bamba torch path
* patch zamba2 and mamba2
* Add a unit test for grouped SSD
* add comment for the new unit test
* add output_size arg value to repeat_interleave calls
* Add comment
1 parent 3bf85b1 commit b813836
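
The substance of the change is the swap from torch.Tensor.repeat to torch.Tensor.repeat_interleave when the grouped B and C projections are broadcast from n_groups up to num_heads. The two calls only differ once there is more than one group and more than one head per group, which is why the bug was confined to grouped SSD configurations. Below is a minimal sketch of the difference; the sizes are toy values chosen for illustration, not the models' defaults.

import torch

# Toy sizes for illustration only: 2 groups, 4 heads -> 2 heads per group.
batch_size, seq_len, n_groups, num_heads, ssm_state_size = 1, 3, 2, 4, 5
B = torch.randn(batch_size, seq_len, n_groups, ssm_state_size)

# Old code: repeat() tiles the group axis, giving the order [g0, g1, g0, g1]
# along dim 2, so head 1 (which belongs to group 0) reads group 1's projection.
B_tiled = B.repeat(1, 1, num_heads // n_groups, 1)

# Fixed code: repeat_interleave() duplicates each group in place, giving
# [g0, g0, g1, g1], so heads 0-1 share group 0 and heads 2-3 share group 1.
B_grouped = B.repeat_interleave(num_heads // n_groups, dim=2, output_size=num_heads)

assert B_tiled.shape == B_grouped.shape == (batch_size, seq_len, num_heads, ssm_state_size)
assert torch.equal(B_grouped[:, :, 1], B[:, :, 0])    # head 1 sees its own group's B
assert not torch.equal(B_tiled[:, :, 1], B[:, :, 0])  # repeat() handed it group 1's B instead

The output_size argument is optional; per the PyTorch docs it lets repeat_interleave avoid computing the output length on the fly (which matters mostly when repeats is a tensor), and here it also pins the expanded dimension to num_heads.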

File tree

6 files changed, +18 -10 lines changed


src/transformers/models/bamba/modeling_bamba.py

Lines changed: 2 additions & 2 deletions
@@ -783,8 +783,8 @@ def torch_forward(
 hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
 B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
 C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
-B = B.repeat(1, 1, self.num_heads // self.n_groups, 1)
-C = C.repeat(1, 1, self.num_heads // self.n_groups, 1)
+B = B.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
+C = C.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
 pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size

 D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)

src/transformers/models/bamba/modular_bamba.py

Lines changed: 2 additions & 2 deletions
@@ -580,8 +580,8 @@ def torch_forward(
 hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
 B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
 C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
-B = B.repeat(1, 1, self.num_heads // self.n_groups, 1)
-C = C.repeat(1, 1, self.num_heads // self.n_groups, 1)
+B = B.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
+C = C.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
 pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size

 D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)

src/transformers/models/mamba2/modeling_mamba2.py

Lines changed: 2 additions & 2 deletions
@@ -572,8 +572,8 @@ def torch_forward(self, input_states, cache_params: Optional[Mamba2Cache]=None,
 hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
 B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
 C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
-B = B.repeat(1, 1, self.num_heads // self.n_groups, 1)
-C = C.repeat(1, 1, self.num_heads // self.n_groups, 1)
+B = B.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
+C = C.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
 pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size

 D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)

src/transformers/models/zamba2/modeling_zamba2.py

Lines changed: 2 additions & 2 deletions
@@ -860,8 +860,8 @@ def torch_forward(self, input_states, cache_params: Optional[Zamba2HybridDynamic
 hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
 B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
 C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
-B = B.repeat(1, 1, self.num_heads // self.n_groups, 1)
-C = C.repeat(1, 1, self.num_heads // self.n_groups, 1)
+B = B.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
+C = C.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
 pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size

 D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)

src/transformers/models/zamba2/modular_zamba2.py

Lines changed: 2 additions & 2 deletions
@@ -630,8 +630,8 @@ def torch_forward(self, input_states, cache_params: Optional[Zamba2HybridDynamic
 hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
 B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
 C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
-B = B.repeat(1, 1, self.num_heads // self.n_groups, 1)
-C = C.repeat(1, 1, self.num_heads // self.n_groups, 1)
+B = B.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
+C = C.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
 pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size

 D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)

tests/models/mamba2/test_modeling_mamba2.py

Lines changed: 8 additions & 0 deletions
@@ -238,6 +238,14 @@ def test_mamba2_slow_vs_fast_forward(self):
 config_and_inputs = self.model_tester.prepare_config_and_inputs()
 self.model_tester.create_and_check_mamba2_slow_vs_fast_forward(*config_and_inputs)

+# This test adjusts n_groups to half the original setting and effectively
+# creates a grouped SSD configuration in the mamba2 layers
+# See https://github.com/huggingface/transformers/pull/37533/
+def test_mamba2_slow_vs_fast_forward_grouped(self):
+    config_and_inputs = self.model_tester.prepare_config_and_inputs()
+    config_and_inputs[0].n_groups //= 2
+    self.model_tester.create_and_check_mamba2_slow_vs_fast_forward(*config_and_inputs)
+
 def test_initialization(self):
 config, _ = self.model_tester.prepare_config_and_inputs_for_common()
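
The new test is the existing slow-vs-fast check with one extra line: it halves the tester's n_groups, pushing the model into the grouped regime where each group's B and C serve more than one head, which is exactly where repeat() and repeat_interleave() disagree. The sketch below (toy shapes, chosen for illustration) shows the two degenerate settings in which the calls coincide, so a slow-vs-fast comparison cannot tell them apart.

import torch

ssm_state_size = 5

# One group per head: the expansion factor is 1, so both calls are no-ops.
B = torch.randn(1, 3, 4, ssm_state_size)   # n_groups == num_heads == 4
assert torch.equal(B.repeat(1, 1, 1, 1),
                   B.repeat_interleave(1, dim=2, output_size=4))

# A single group: tiling the one group and interleaving it give the same tensor.
B = torch.randn(1, 3, 1, ssm_state_size)   # n_groups == 1, num_heads == 4
assert torch.equal(B.repeat(1, 1, 4, 1),
                   B.repeat_interleave(4, dim=2, output_size=4))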
