diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index e9e801bb7167..baa33421d953 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -871,15 +871,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if top_x.shape[0] == 0: continue - # in torch it is faster to index using lists than torch tensors - top_x_list = top_x.tolist() - idx_list = idx.tolist() - # Index the correct hidden states and compute the expert hidden state for # the current expert. We need to make sure to multiply the output hidden # states by `routing_weights` on the corresponding tokens (top-1 and top-2) - current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim) - current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None] + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] # However `index_add_` only support torch tensors for indexing so we'll use # the `top_x` tensor here. diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index e921af9232dd..cab2ef5ff7e5 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -843,15 +843,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if top_x.shape[0] == 0: continue - # in torch it is faster to index using lists than torch tensors - top_x_list = top_x.tolist() - idx_list = idx.tolist() - # Index the correct hidden states and compute the expert hidden state for # the current expert. We need to make sure to multiply the output hidden # states by `routing_weights` on the corresponding tokens (top-1 and top-2) - current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim) - current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None] + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] # However `index_add_` only support torch tensors for indexing so we'll use # the `top_x` tensor here.