Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MPNet ONNX export #691

Merged
merged 4 commits into from
Jan 13, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ They specify which input generators should be used for the dummy inputs, but rem
- Marian
- MobileBert
- MobileVit
- MPNet
- OwlVit
- Pegasus
- Perceiver
Expand Down
4 changes: 4 additions & 0 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]:
return {"input_ids": dynamic_axis, "attention_mask": dynamic_axis}


class MPNetOnnxConfig(DistilBertOnnxConfig):
DEFAULT_ONNX_OPSET = 12
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just for my knowledge, is opset 11 failing for this model? If you want, you can add a comment as to why opset 12 is required.

Copy link
Contributor Author

@jplu jplu Jan 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here the error when export with opset 11:

An error occured with the error message: [ONNXRuntimeError] : 10 : INVALID_GRAPH : Load model from here/model.onnx failed:This is an invalid model. Type Error: Type 'tensor(int64)' of input parameter (/encoder/Add_1_output_0) of operator (Min) in node (/encoder/Min) is invalid..

Apparently the Min operator is not compliant with int64 tensors (as far as what I understand from this message)



class RobertaOnnxConfig(DistilBertOnnxConfig):
pass

Expand Down
9 changes: 9 additions & 0 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,15 @@ class TasksManager:
"image-classification",
onnx="MobileNetV2OnnxConfig",
),
"mpnet": supported_tasks_mapping(
"default",
"masked-lm",
"sequence-classification",
"multiple-choice",
"token-classification",
"question-answering",
onnx="MPNetOnnxConfig",
),
"mt5": supported_tasks_mapping(
"default",
"default-with-past",
Expand Down
1 change: 1 addition & 0 deletions tests/exporters/exporters_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
"mobilenet-v2": "hf-internal-testing/tiny-random-MobileNetV2Model",
"mobilenet-v1": "google/mobilenet_v1_0.75_192",
"mobilevit": "hf-internal-testing/tiny-random-mobilevit",
"mpnet", "hf-internal-testing/tiny-random-mpnet",
"mt5": "lewtun/tiny-random-mt5",
# "owlvit": "google/owlvit-base-patch32",
"pegasus": "hf-internal-testing/tiny-random-PegasusModel",
Expand Down