Skip to content

Commit

Permalink
Fix sentence transformers model patching (#1936)
Browse files Browse the repository at this point in the history
fix sentence transformers modeling patching for export
  • Loading branch information
echarlaix authored Jul 2, 2024
1 parent 16d4d72 commit ae591be
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion optimum/exporters/onnx/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,7 +771,8 @@ def __init__(
):
super().__init__(config, model, model_kwargs)

self._update_causal_mask_original = self._model[0].auto_model._update_causal_mask
if _transformers_version >= version.parse("4.42") and self.real_config._config.model_type == "mistral":
self._update_causal_mask_original = self._model[0].auto_model._update_causal_mask

def patched_forward(input_ids, attention_mask):
result = self.orig_forward({"input_ids": input_ids, "attention_mask": attention_mask})
Expand Down

0 comments on commit ae591be

Please sign in to comment.