diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 6788dcde8c3..414ffe671ae 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -663,10 +663,10 @@ def __init__(self, config: MixtralConfig): self.act_fn = ACT2FN[config.hidden_act] - def forward(self, hidden_states, routing_weights): + def forward(self, hidden_states): current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) current_hidden_states = self.w2(current_hidden_states) - return routing_weights * current_hidden_states + return current_hidden_states MISTRAL_ATTENTION_CLASSES = { @@ -736,7 +736,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # the current expert. We need to make sure to multiply the output hidden # states by `routing_weights` on the corresponding tokens (top-1 and top-2) current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim) - current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None]) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None] # However `index_add_` only support torch tensors for indexing so we'll use # the `top_x` tensor here.