Skip to content

Commit

Permalink
Disable Mixtral output_router_logits during inference (#29249)
Browse files Browse the repository at this point in the history
* Set output_router_logits=False in prepare_inputs_for_generation for mixtral

* Add output_router_logits=False to prepare_inputs_for_generation for mixtral

* Fix style
  • Loading branch information
LeonardoEmili authored Feb 28, 2024
1 parent 8a8a0a4 commit 2ce56d3
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion src/transformers/models/mixtral/modeling_mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -1415,7 +1415,13 @@ def forward(
)

def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
output_router_logits=False,
**kwargs,
):
# Omit tokens covered by past_key_values
if past_key_values is not None:
Expand Down Expand Up @@ -1467,6 +1473,7 @@ def prepare_inputs_for_generation(
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"output_router_logits": output_router_logits,
}
)
return model_inputs
Expand Down

0 comments on commit 2ce56d3

Please sign in to comment.