Skip to content

Commit

Permalink
Fix mixtral ONNX Exporter Issue. (#29858)
Browse files Browse the repository at this point in the history
* fix mixtral onnx export

* fix qwen model
  • Loading branch information
AdamLouly authored Apr 5, 2024
1 parent 79d62b2 commit d704c0b
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 12 deletions.
8 changes: 2 additions & 6 deletions src/transformers/models/mixtral/modeling_mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,15 +871,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if top_x.shape[0] == 0:
continue

- # in torch it is faster to index using lists than torch tensors
- top_x_list = top_x.tolist()
- idx_list = idx.tolist()
-
  # Index the correct hidden states and compute the expert hidden state for
  # the current expert. We need to make sure to multiply the output hidden
  # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
- current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
- current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]
+ current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+ current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

# However `index_add_` only support torch tensors for indexing so we'll use
# the `top_x` tensor here.
Expand Down
8 changes: 2 additions & 6 deletions src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,15 +843,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if top_x.shape[0] == 0:
continue

- # in torch it is faster to index using lists than torch tensors
- top_x_list = top_x.tolist()
- idx_list = idx.tolist()
-
  # Index the correct hidden states and compute the expert hidden state for
  # the current expert. We need to make sure to multiply the output hidden
  # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
- current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
- current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]
+ current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+ current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

# However `index_add_` only support torch tensors for indexing so we'll use
# the `top_x` tensor here.
Expand Down

0 comments on commit d704c0b

Please sign in to comment.