Nystromformer ONNX export #728

Merged
13 commits merged on Jan 31, 2023
@@ -107,6 +107,7 @@ They specify which input generators should be used for the dummy inputs, but rem
- MobileVit
- MPNet
- Nystromformer
- OwlVit
- Pegasus
- Perceiver
- PoolFormer
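With Nystromformer added to the supported-architectures list above, the model can be exported through the regular optimum ONNX exporter entry point. A minimal sketch, assuming an optimum version where main_export backs the python -m optimum.exporters.onnx CLI; the checkpoint and task are illustrative choices, not part of this PR:

from optimum.exporters.onnx import main_export

# Export an assumed public Nystromformer checkpoint to ONNX for masked-lm,
# one of the tasks enabled by this PR. The output directory receives model.onnx.
main_export(
    "uw-madison/nystromformer-512",
    output="nystromformer_onnx",
    task="masked-lm",
)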
4 changes: 4 additions & 0 deletions optimum/exporters/onnx/model_configs.py
@@ -97,6 +97,10 @@ class MobileBertOnnxConfig(BertOnnxConfig):
pass


class NystromformerOnnxConfig(BertOnnxConfig):
pass


class XLMOnnxConfig(BertOnnxConfig):
pass

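Since NystromformerOnnxConfig simply subclasses BertOnnxConfig, it inherits the standard text-encoder input/output specification. A small sketch of what that gives you, with an assumed checkpoint name not taken from this PR:

from transformers import AutoConfig
from optimum.exporters.onnx.model_configs import NystromformerOnnxConfig

# Build the ONNX config for one of the newly supported tasks and inspect
# the inputs/outputs it declares (inherited from BertOnnxConfig).
config = AutoConfig.from_pretrained("uw-madison/nystromformer-512")
onnx_config = NystromformerOnnxConfig(config, task="masked-lm")

print(onnx_config.inputs)   # expected: input_ids, attention_mask, token_type_ids with dynamic batch/sequence axes
print(onnx_config.outputs)  # task-dependent, here the masked-lm logits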
9 changes: 9 additions & 0 deletions optimum/exporters/tasks.py
@@ -511,6 +511,15 @@ class TasksManager:
# "zero-shot-object-detection",
# onnx="OwlViTOnnxConfig",
# ),
"nystromformer": supported_tasks_mapping(
"default",
"masked-lm",
"multiple-choice",
"question-answering",
"sequence-classification",
"token-classification",
onnx="NystromformerOnnxConfig",
),
"pegasus": supported_tasks_mapping(
"default",
"default-with-past",
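The new supported_tasks_mapping entry is what TasksManager consults when resolving a model type to an ONNX config. A hedged sanity check, assuming the same lookup helper the other architectures go through:

from optimum.exporters.tasks import TasksManager

# List the tasks registered for nystromformer with the ONNX exporter;
# per this diff they should include default, masked-lm, multiple-choice,
# question-answering, sequence-classification and token-classification.
tasks = TasksManager.get_supported_tasks_for_model_type("nystromformer", "onnx")
print(sorted(tasks))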
1 change: 1 addition & 0 deletions optimum/onnxruntime/utils.py
@@ -93,6 +93,7 @@ class ORTConfigManager:
"mbart": "bart",
"mt5": "bart",
"m2m_100": "bart",
"Nystromformer": "bert",
"roberta": "bert",
"t5": "t5",
"whisper": "whisper",
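Mapping nystromformer onto the bert entry lets ORTConfigManager hand the exported graph to onnxruntime's BERT optimization recipe. A minimal usage sketch, assuming the from_transformers=True export path of the optimum API of that era and an illustrative checkpoint:

from optimum.onnxruntime import ORTModelForSequenceClassification, ORTOptimizer
from optimum.onnxruntime.configuration import OptimizationConfig

# Export on the fly, then run graph optimizations that reuse the bert settings
# selected through the mapping added above.
ort_model = ORTModelForSequenceClassification.from_pretrained(
    "uw-madison/nystromformer-512", from_transformers=True
)
optimizer = ORTOptimizer.from_pretrained(ort_model)
optimizer.optimize(
    optimization_config=OptimizationConfig(optimization_level=1),
    save_dir="nystromformer_optimized",
)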
1 change: 1 addition & 0 deletions optimum/utils/normalized_config.py
@@ -176,6 +176,7 @@ class NormalizedConfigManager:
"mbart": BartLikeNormalizedTextConfig,
"mt5": T5LikeNormalizedTextConfig,
"m2m_100": BartLikeNormalizedTextConfig,
"nystromformer": NormalizedTextConfig,
"pegasus": BartLikeNormalizedTextConfig,
"poolformer": NormalizedVisionConfig,
"resnet": NormalizedVisionConfig,
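Registering nystromformer with NormalizedTextConfig gives downstream code (dummy-input generators, ORT modeling classes) uniform attribute access. A short sketch, with the lookup helper and checkpoint assumed rather than taken from this diff:

from transformers import AutoConfig
from optimum.utils.normalized_config import NormalizedConfigManager

# Resolve the normalized config class for the model type and read the usual
# text-model attributes through the normalized names.
config = AutoConfig.from_pretrained("uw-madison/nystromformer-512")
normalized = NormalizedConfigManager.get_normalized_config_class("nystromformer")(config)
print(normalized.hidden_size, normalized.num_attention_heads)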
2 changes: 2 additions & 0 deletions tests/exporters/exporters_utils.py
@@ -71,6 +71,7 @@
"mpnet": "hf-internal-testing/tiny-random-MPNetModel",
"mt5": "lewtun/tiny-random-mt5",
# "owlvit": "google/owlvit-base-patch32",
"nystromformer": "hf-internal-testing/tiny-random-NystromformerModel",
"pegasus": "hf-internal-testing/tiny-random-PegasusModel",
"perceiver": {
"hf-internal-testing/tiny-random-language_perceiver": ["masked-lm", "sequence-classification"],
@@ -157,6 +158,7 @@
"mobilevit": "apple/mobilevit-small",
"mt5": "lewtun/tiny-random-mt5", # Not using google/mt5-small because it takes too much time for testing.
"owlvit": "google/owlvit-base-patch32",
"nystromformer": "uw-madison/nystromformer-1024", # Not using the 2048 or 4096 models
"perceiver": "hf-internal-testing/tiny-random-PerceiverModel", # Not using deepmind/language-perceiver because it takes too much time for testing.
# "rembert": "google/rembert",
"poolformer": "hf-internal-testing/tiny-random-PoolFormerModel",
4 changes: 4 additions & 0 deletions tests/onnxruntime/test_modeling.py
@@ -121,6 +121,7 @@
"mobilenet_v2": "hf-internal-testing/tiny-random-MobileNetV2Model",
"mobilevit": "hf-internal-testing/tiny-random-mobilevit",
"mt5": "lewtun/tiny-random-mt5",
"nystromformer": "uw-madison/nystromformer-1024", # hf-internal-testing/tiny-random-NystromformerModel
"pegasus": "hf-internal-testing/tiny-random-PegasusModel",
"poolformer": "hf-internal-testing/tiny-random-PoolFormerModel",
"resnet": "hf-internal-testing/tiny-random-resnet",
@@ -884,6 +885,7 @@ class ORTModelForQuestionAnsweringIntegrationTest(ORTModelTestMixin):
# "layoutlmv3",
"mbart",
"mobilebert",
"nystromformer",
"roberta",
"roformer",
"squeezebert",
@@ -1039,6 +1041,7 @@ class ORTModelForSequenceClassificationIntegrationTest(ORTModelTestMixin):
# "layoutlmv3",
"mbart",
"mobilebert",
"nystromformer",
# "perceiver",
"roberta",
"roformer",
@@ -1447,6 +1450,7 @@ class ORTModelForMultipleChoiceIntegrationTest(ORTModelTestMixin):
"flaubert",
"ibert",
"mobilebert",
"nystromformer",
"roberta",
"roformer",
"squeezebert",
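The test-list additions above exercise Nystromformer through ORTModelForQuestionAnswering, ORTModelForSequenceClassification and ORTModelForMultipleChoice. An end-to-end sketch along the same lines; the checkpoint is a base (not fine-tuned) model chosen for illustration, so the answer it produces is not meaningful:

from transformers import AutoTokenizer, pipeline
from optimum.onnxruntime import ORTModelForQuestionAnswering

model_id = "uw-madison/nystromformer-512"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Export the PyTorch checkpoint to ONNX and load it behind the ORT wrapper.
model = ORTModelForQuestionAnswering.from_pretrained(model_id, from_transformers=True)

qa = pipeline("question-answering", model=model, tokenizer=tokenizer)
print(qa(question="What was exported?", context="The Nystromformer model was exported to ONNX."))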