Skip to content

Commit

Permalink
shape comment
Browse files Browse the repository at this point in the history
  • Loading branch information
alxndrTL committed Jul 22, 2024
1 parent ed5fcc9 commit bec1543
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/transformers/models/mamba/modeling_mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, ca

# 3.c perform the recurrence y ← SSM(A, B, C)(x)
if self.use_mambapy and self.training and cache_params is None:
- hs = pscan(discrete_A.transpose(1, 2), deltaB_u.transpose(1, 2)) # [batch, intermediate_size, seq_len, ssm_state_size]
+ hs = pscan(discrete_A.transpose(1, 2), deltaB_u.transpose(1, 2)) # [batch, seq_len, intermediate_size, ssm_state_size]

scan_output = (hs @ C.unsqueeze(-1)).squeeze(3).transpose(1, 2) # [batch, intermediate_size, seq_len]
scan_output = scan_output + hidden_states * self.D[None, :, None]
Expand Down

0 comments on commit bec1543

Please sign in to comment.