Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Nystromformer ONNX export #728

Merged
merged 13 commits into from
Jan 31, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ They specify which input generators should be used for the dummy inputs, but rem
- MobileBert
- MobileVit
- MPNet
- Nystromformer
- OwlVit
- Pegasus
- Perceiver
Expand Down
4 changes: 4 additions & 0 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ class MobileBertOnnxConfig(BertOnnxConfig):
pass


class NystromformerOnnxConfig(BertOnnxConfig):
pass


class XLMOnnxConfig(BertOnnxConfig):
pass

Expand Down
9 changes: 9 additions & 0 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,15 @@ class TasksManager:
"seq2seq-lm-with-past",
onnx="M2M100OnnxConfig",
),
"nystromformer": supported_tasks_mapping(
"default",
"masked-lm",
"multiple-choice",
"question-answering",
"sequence-classification",
"token-classification",
onnx="NystromformerOnnxConfig",
),
# TODO: owlvit is actually not yet supported in exporters
# "owlvit": supported_tasks_mapping(
# "default",
Expand Down
1 change: 1 addition & 0 deletions optimum/onnxruntime/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ class ORTConfigManager:
"mbart": "bart",
"mt5": "bart",
"m2m_100": "bart",
"nystromformer": "bert",
"roberta": "bert",
"t5": "t5",
"whisper": "whisper",
Expand Down
1 change: 1 addition & 0 deletions optimum/utils/normalized_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ class NormalizedConfigManager:
"mbart": BartLikeNormalizedTextConfig,
"mt5": T5LikeNormalizedTextConfig,
"m2m_100": BartLikeNormalizedTextConfig,
"nystromformer": NormalizedTextConfig,
"pegasus": BartLikeNormalizedTextConfig,
"poolformer": NormalizedVisionConfig,
"resnet": NormalizedVisionConfig,
Expand Down
2 changes: 2 additions & 0 deletions tests/exporters/exporters_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
"mobilevit": "hf-internal-testing/tiny-random-mobilevit",
"mpnet": "hf-internal-testing/tiny-random-MPNetModel",
"mt5": "lewtun/tiny-random-mt5",
"nystromformer": "hf-internal-testing/tiny-random-NystromformerModel",
# "owlvit": "google/owlvit-base-patch32",
"pegasus": "hf-internal-testing/tiny-random-PegasusModel",
"perceiver": {
Expand Down Expand Up @@ -165,6 +166,7 @@
# "mobilenet_v2": "google/mobilenet_v2_0.35_96",
"mobilevit": "apple/mobilevit-small",
"mt5": "lewtun/tiny-random-mt5", # Not using google/mt5-small because it takes too much time for testing.
"nystromformer": "hf-internal-testing/tiny-random-NystromformerModel",
"owlvit": "google/owlvit-base-patch32",
"perceiver": "hf-internal-testing/tiny-random-PerceiverModel", # Not using deepmind/language-perceiver because it takes too much time for testing.
# "rembert": "google/rembert",
Expand Down
4 changes: 4 additions & 0 deletions tests/onnxruntime/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@
"mobilenet_v2": "hf-internal-testing/tiny-random-MobileNetV2Model",
"mobilevit": "hf-internal-testing/tiny-random-mobilevit",
"mt5": "lewtun/tiny-random-mt5",
"nystromformer": "hf-internal-testing/tiny-random-NystromformerModel",
"pegasus": "hf-internal-testing/tiny-random-PegasusModel",
"poolformer": "hf-internal-testing/tiny-random-PoolFormerModel",
"resnet": "hf-internal-testing/tiny-random-resnet",
Expand Down Expand Up @@ -886,6 +887,7 @@ class ORTModelForQuestionAnsweringIntegrationTest(ORTModelTestMixin):
# "layoutlmv3",
"mbart",
"mobilebert",
"nystromformer",
"roberta",
"roformer",
"squeezebert",
Expand Down Expand Up @@ -1187,6 +1189,7 @@ class ORTModelForSequenceClassificationIntegrationTest(ORTModelTestMixin):
# "layoutlmv3",
"mbart",
"mobilebert",
"nystromformer",
# "perceiver",
"roberta",
"roformer",
Expand Down Expand Up @@ -1595,6 +1598,7 @@ class ORTModelForMultipleChoiceIntegrationTest(ORTModelTestMixin):
"flaubert",
"ibert",
"mobilebert",
"nystromformer",
"roberta",
"roformer",
"squeezebert",
Expand Down