diff --git a/docs/source/exporters/onnx/usage_guides/contribute.mdx b/docs/source/exporters/onnx/usage_guides/contribute.mdx index 6e1d9d3ad9..e18248fc9c 100644 --- a/docs/source/exporters/onnx/usage_guides/contribute.mdx +++ b/docs/source/exporters/onnx/usage_guides/contribute.mdx @@ -129,7 +129,7 @@ classification head, we could use: >>> from transformers import AutoConfig >>> config = AutoConfig.from_pretrained("bert-base-uncased") ->>> onnx_config_for_seq_clf = BertOnnxConfig(config, task="sequence-classification") +>>> onnx_config_for_seq_clf = BertOnnxConfig(config, task="text-classification") >>> print(onnx_config_for_seq_clf.outputs) OrderedDict([('logits', {0: 'batch_size'})]) ``` @@ -157,9 +157,9 @@ For BERT, it looks as follows: ```python "bert": supported_tasks_mapping( - "default", - "masked-lm", - "causal-lm", - "sequence-classification", + "feature-extraction", + "fill-mask", + "text-generation", + "text-classification", "multiple-choice", "token-classification", "question-answering", diff --git a/docs/source/exporters/onnx/usage_guides/export_a_model.mdx b/docs/source/exporters/onnx/usage_guides/export_a_model.mdx index 77774c08c4..33d6656f24 100644 --- a/docs/source/exporters/onnx/usage_guides/export_a_model.mdx +++ b/docs/source/exporters/onnx/usage_guides/export_a_model.mdx @@ -87,7 +87,7 @@ Required arguments: output Path indicating the directory where to store generated ONNX model. Optional arguments: - --task TASK The task to export the model for. If not specified, the task will be auto-inferred based on the model. Available tasks depend on the model, but are among: ['default', 'masked-lm', 'causal-lm', 'seq2seq-lm', 'sequence-classification', 'token-classification', 'multiple-choice', 'object-detection', 'question-answering', 'image-classification', 'image-segmentation', 'masked-im', 'semantic-segmentation', 'speech2seq-lm', 'audio-classification', 'audio-frame-classification', 'audio-ctc', 'audio-xvector', 'vision2seq-lm', 'stable-diffusion', 'zero-shot-object-detection']. For decoder models, use `xxx-with-past` to export the model using past key values in the decoder. + --task TASK The task to export the model for. If not specified, the task will be auto-inferred based on the model. Available tasks depend on the model, but are among: ['feature-extraction', 'fill-mask', 'text-generation', 'text2text-generation', 'text-classification', 'token-classification', 'multiple-choice', 'object-detection', 'question-answering', 'image-classification', 'image-segmentation', 'masked-im', 'semantic-segmentation', 'automatic-speech-recognition', 'audio-classification', 'audio-frame-classification', 'audio-xvector', 'image-to-text', 'stable-diffusion', 'zero-shot-object-detection']. For decoder models, use `xxx-with-past` to export the model using past key values in the decoder. --monolith Force to export the model as a single ONNX file. By default, the ONNX exporter may break the model in several ONNX files, for example for encoder-decoder models where the encoder should be run only once while the decoder is looped over. --device DEVICE The device to use to do the export. Defaults to "cpu". --opset OPSET If specified, ONNX opset version to export the model with. Otherwise, the default opset will be used. 
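The renamed task list above stays backward compatible: legacy names are remapped through the `_SYNONYM_TASK_MAP` added to `optimum/exporters/tasks.py` further down in this diff, via the `TasksManager.map_from_synonym` call visible in `optimum/exporters/onnx/__main__.py`. The method body itself is not part of the diff; below is a minimal sketch of the assumed behavior, a plain dictionary lookup that falls back to the input task name.

```python
# Hedged sketch of the synonym normalization; the actual map_from_synonym
# implementation is not shown in this diff. Entries are copied from the
# _SYNONYM_TASK_MAP introduced in optimum/exporters/tasks.py below.
_SYNONYM_TASK_MAP = {
    "sequence-classification": "text-classification",
    "causal-lm": "text-generation",
    "seq2seq-lm": "text2text-generation",
    "masked-lm": "fill-mask",
    "default": "feature-extraction",
    "audio-ctc": "automatic-speech-recognition",
}

def map_from_synonym(task: str) -> str:
    # Canonical or unknown task names pass through unchanged.
    return _SYNONYM_TASK_MAP.get(task, task)

assert map_from_synonym("causal-lm") == "text-generation"
assert map_from_synonym("text-classification") == "text-classification"
```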
@@ -198,7 +198,7 @@ Models exported through `optimum-cli export onnx` can be used directly in [`~onn When exporting a decoder model used for generation, it can be useful to encapsulate in the exported ONNX the [reuse of past keys and values](https://discuss.huggingface.co/t/what-is-the-purpose-of-use-cache-in-decoder/958/2). This allows to avoid recomputing the same intermediate activations during the generation. -In the ONNX export, the past keys/values are reused by default. This behavior corresponds to `--task seq2seq-lm-with-past`, `--task causal-lm-with-past`, or `--task speech2seq-lm-with-past`. If for any purpose you would like to disable the export with past keys/values reuse, passing explicitly to `optimum-cli export onnx` the task `seq2seq-lm`, `causal-lm` or `speech2seq-lm` is required. +In the ONNX export, the past keys/values are reused by default. This behavior corresponds to `--task text2text-generation-with-past`, `--task text-generation-with-past`, or `--task automatic-speech-recognition-with-past`. To disable the reuse of past keys/values, explicitly pass the task `text2text-generation`, `text-generation` or `automatic-speech-recognition` to `optimum-cli export onnx`. A model exported using past key/values can be reused directly into Optimum's [`~onnxruntime.ORTModel`]: @@ -235,7 +235,7 @@ For each model architecture, you can find the list of supported tasks via the [` >>> distilbert_tasks = list(TasksManager.get_supported_tasks_for_model_type("distilbert", "onnx").keys()) >>> print(distilbert_tasks) -['default', 'masked-lm', 'sequence-classification', 'multiple-choice', 'token-classification', 'question-answering'] +['feature-extraction', 'fill-mask', 'text-classification', 'multiple-choice', 'token-classification', 'question-answering'] ``` You can then pass one of these tasks to the `--task` argument in the `optimum-cli export onnx` command, as mentioned above. 
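As a sanity check for these renames, the supported-task query shown above can be combined with the `-with-past` defaulting described in the hunk at line 198. A hedged example, assuming a version of `optimum` carrying this diff is installed:

```python
from optimum.exporters.tasks import TasksManager

# Tasks the ONNX exporter supports for GPT-2 after the rename
# (per the gpt2 entry in optimum/exporters/tasks.py below).
gpt2_tasks = TasksManager.get_supported_tasks_for_model_type("gpt2", "onnx")

# Mirrors the default in optimum/exporters/onnx/__main__.py: prefer the
# "-with-past" variant of the task whenever it is supported.
task = "text-generation"
if task + "-with-past" in gpt2_tasks:
    task += "-with-past"

print(task)  # expected: text-generation-with-past
```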
diff --git a/docs/source/exporters/task_manager.mdx b/docs/source/exporters/task_manager.mdx index 6b88ab3a69..6d7f97399c 100644 --- a/docs/source/exporters/task_manager.mdx +++ b/docs/source/exporters/task_manager.mdx @@ -35,7 +35,7 @@ It is possible to know which tasks are supported for a model for a given backend >>> distilbert_tasks = list(TasksManager.get_supported_tasks_for_model_type(model_type, backend).keys()) >>> print(distilbert_tasks) -['default', 'masked-lm', 'sequence-classification', 'multiple-choice', 'token-classification', 'question-answering'] +['feature-extraction', 'fill-mask', 'text-classification', 'multiple-choice', 'token-classification', 'question-answering'] ``` @@ -44,12 +44,12 @@ It is possible to know which tasks are supported for a model for a given backend | Task | Auto Class | |--------------------------------------|--------------------------------------| -| `causal-lm`, `causal-lm-with-past` | `AutoModelForCausalLM` | -| `default`, `default-with-past` | `AutoModel` | -| `masked-lm` | `AutoModelForMaskedLM` | +| `text-generation`, `text-generation-with-past` | `AutoModelForCausalLM` | +| `feature-extraction`, `feature-extraction-with-past` | `AutoModel` | +| `fill-mask` | `AutoModelForMaskedLM` | | `question-answering` | `AutoModelForQuestionAnswering` | -| `seq2seq-lm`, `seq2seq-lm-with-past` | `AutoModelForSeq2SeqLM` | -| `sequence-classification` | `AutoModelForSequenceClassification` | +| `text2text-generation`, `text2text-generation-with-past` | `AutoModelForSeq2SeqLM` | +| `text-classification` | `AutoModelForSequenceClassification` | | `token-classification` | `AutoModelForTokenClassification` | | `multiple-choice` | `AutoModelForMultipleChoice` | | `image-classification` | `AutoModelForImageClassification` | @@ -57,18 +57,18 @@ It is possible to know which tasks are supported for a model for a given backend | `image-segmentation` | `AutoModelForImageSegmentation` | | `masked-im` | `AutoModelForMaskedImageModeling` | | `semantic-segmentation` | `AutoModelForSemanticSegmentation` | -| `speech2seq-lm` | `AutoModelForSpeechSeq2Seq` | +| `automatic-speech-recognition` | `AutoModelForSpeechSeq2Seq` | ### TensorFlow | Task | Auto Class | |--------------------------------------|----------------------------------------| -| `causal-lm`, `causal-lm-with-past` | `TFAutoModelForCausalLM` | +| `text-generation`, `text-generation-with-past` | `TFAutoModelForCausalLM` | -| `default`, `default-with-past` | `TFAutoModel` | +| `feature-extraction`, `feature-extraction-with-past` | `TFAutoModel` | -| `masked-lm` | `TFAutoModelForMaskedLM` | +| `fill-mask` | `TFAutoModelForMaskedLM` | | `question-answering` | `TFAutoModelForQuestionAnswering` | -| `seq2seq-lm`, `seq2seq-lm-with-past` | `TFAutoModelForSeq2SeqLM` | -| `sequence-classification` | `TFAutoModelForSequenceClassification` | +| `text2text-generation`, `text2text-generation-with-past` | `TFAutoModelForSeq2SeqLM` | +| `text-classification` | `TFAutoModelForSequenceClassification` | | `token-classification` | `TFAutoModelForTokenClassification` | | `multiple-choice` | `TFAutoModelForMultipleChoice` | | `semantic-segmentation` | `TFAutoModelForSemanticSegmentation` | diff --git a/docs/source/exporters/tflite/usage_guides/export_a_model.mdx b/docs/source/exporters/tflite/usage_guides/export_a_model.mdx index 9fbd28b05a..8666f44543 100644 --- a/docs/source/exporters/tflite/usage_guides/export_a_model.mdx +++ b/docs/source/exporters/tflite/usage_guides/export_a_model.mdx @@ -56,9 +56,9 @@ Required arguments: Optional arguments: --task TASK The task to export the model for. 
If not specified, the task will be auto-inferred based on the model. Available tasks depend on - the model, but are among: ['default', 'masked-lm', 'causal-lm', 'seq2seq-lm', 'sequence-classification', 'token-classification', + the model, but are among: ['feature-extraction', 'fill-mask', 'text-generation', 'text2text-generation', 'text-classification', 'token-classification', 'multiple-choice', 'object-detection', 'question-answering', 'image-classification', 'image-segmentation', 'masked-im', 'semantic- - segmentation', 'speech2seq-lm', 'audio-classification', 'audio-frame-classification', 'audio-ctc', 'audio-xvector', 'vision2seq- + segmentation', 'automatic-speech-recognition', 'audio-classification', 'audio-frame-classification', 'audio-xvector', 'image-to- - lm', 'stable-diffusion', 'zero-shot-object-detection']. For decoder models, use `xxx-with-past` to export the model using past key values in the decoder. + text', 'stable-diffusion', 'zero-shot-object-detection']. For decoder models, use `xxx-with-past` to export the model using past key values in the decoder. --atol ATOL If specified, the absolute difference tolerance when validating the model. Otherwise, the default atol for the model will be used. diff --git a/docs/source/onnxruntime/usage_guides/trainer.mdx b/docs/source/onnxruntime/usage_guides/trainer.mdx index 49d89897c1..6b466b7257 100644 --- a/docs/source/onnxruntime/usage_guides/trainer.mdx +++ b/docs/source/onnxruntime/usage_guides/trainer.mdx @@ -123,7 +123,7 @@ empowered by ONNX Runtime. Here is an example of how to use `ORTTrainer` compare model=model, args=training_args, train_dataset=train_dataset, -+ feature="sequence-classification", ++ feature="text-classification", ... ) @@ -159,7 +159,7 @@ empowered by ONNX Runtime. Here is an example of how to use `ORTSeq2SeqTrainer` model=model, args=training_args, train_dataset=train_dataset, -+ feature="seq2seq-lm", ++ feature="text2text-generation", ... ) diff --git a/docs/source/quicktour.mdx b/docs/source/quicktour.mdx index 4452915b88..507aee155e 100644 --- a/docs/source/quicktour.mdx +++ b/docs/source/quicktour.mdx @@ -194,7 +194,7 @@ To train transformers with ONNX Runtime's acceleration features, 🤗 Optimum pr model=model, args=training_args, train_dataset=train_dataset, -+ feature="sequence-classification", # The model type to export to ONNX ++ feature="text-classification", # The task to export the model for ... 
) diff --git a/examples/onnxruntime/training/language-modeling/run_clm.py b/examples/onnxruntime/training/language-modeling/run_clm.py index fb72737947..2807d3f721 100644 --- a/examples/onnxruntime/training/language-modeling/run_clm.py +++ b/examples/onnxruntime/training/language-modeling/run_clm.py @@ -528,7 +528,7 @@ def compute_metrics(eval_preds): preprocess_logits_for_metrics=preprocess_logits_for_metrics if training_args.do_eval and not is_torch_tpu_available() else None, - feature="causal-lm", + feature="text-generation", ) # Training diff --git a/examples/onnxruntime/training/language-modeling/run_mlm.py b/examples/onnxruntime/training/language-modeling/run_mlm.py index 2f4d7fd3a2..122395a1cd 100755 --- a/examples/onnxruntime/training/language-modeling/run_mlm.py +++ b/examples/onnxruntime/training/language-modeling/run_mlm.py @@ -563,7 +563,7 @@ def compute_metrics(eval_preds): preprocess_logits_for_metrics=preprocess_logits_for_metrics if training_args.do_eval and not is_torch_tpu_available() else None, - feature="masked-lm", + feature="fill-mask", ) # Training diff --git a/examples/onnxruntime/training/summarization/run_summarization.py b/examples/onnxruntime/training/summarization/run_summarization.py index d3f2236be4..d1264489d8 100644 --- a/examples/onnxruntime/training/summarization/run_summarization.py +++ b/examples/onnxruntime/training/summarization/run_summarization.py @@ -652,7 +652,7 @@ def compute_metrics(eval_preds): tokenizer=tokenizer, data_collator=data_collator, compute_metrics=compute_metrics if training_args.predict_with_generate else None, - feature="seq2seq-lm", + feature="text2text-generation", ) # Training diff --git a/examples/onnxruntime/training/text-classification/run_glue.py b/examples/onnxruntime/training/text-classification/run_glue.py index c1255180c3..7a81a2ff15 100644 --- a/examples/onnxruntime/training/text-classification/run_glue.py +++ b/examples/onnxruntime/training/text-classification/run_glue.py @@ -532,7 +532,7 @@ def compute_metrics(p: EvalPrediction): compute_metrics=compute_metrics, tokenizer=tokenizer, data_collator=data_collator, - feature="sequence-classification", + feature="text-classification", ) # Training diff --git a/examples/onnxruntime/training/translation/run_translation.py b/examples/onnxruntime/training/translation/run_translation.py index ef9f565306..e410454f2f 100644 --- a/examples/onnxruntime/training/translation/run_translation.py +++ b/examples/onnxruntime/training/translation/run_translation.py @@ -575,7 +575,7 @@ def compute_metrics(eval_preds): tokenizer=tokenizer, data_collator=data_collator, compute_metrics=compute_metrics if training_args.predict_with_generate else None, - feature="seq2seq-lm", + feature="text2text-generation", ) # Training diff --git a/optimum/commands/export/tflite.py b/optimum/commands/export/tflite.py index de453c0fb0..164c442ac0 100644 --- a/optimum/commands/export/tflite.py +++ b/optimum/commands/export/tflite.py @@ -189,7 +189,7 @@ def parse_args_tflite(parser: "ArgumentParser"): default=None, help=( "The name of the column in the dataset containing the main data to preprocess. " - "Only for sequence-classification and token-classification. " + "Only for text-classification and token-classification. " ), ) calibration_dataset_group.add_argument( @@ -198,7 +198,7 @@ def parse_args_tflite(parser: "ArgumentParser"): default=None, help=( "The name of the second column in the dataset containing the main data to preprocess, not always needed. 
" - "Only for sequence-classification and token-classification. " + "Only for text-classification and token-classification. " ), ) calibration_dataset_group.add_argument( diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py index e4a41a66c5..c53d0b3580 100644 --- a/optimum/exporters/onnx/__main__.py +++ b/optimum/exporters/onnx/__main__.py @@ -146,14 +146,7 @@ def main_export( ) original_task = task - # Infer the task - if task == "auto": - try: - task = TasksManager.infer_task_from_model(model_name_or_path) - except KeyError as e: - raise KeyError( - f"The task could not be automatically inferred. Please provide the argument --task with the task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}" - ) + task = TasksManager.map_from_synonym(task) framework = TasksManager.determine_framework(model_name_or_path, subfolder=subfolder, framework=framework) @@ -187,6 +180,14 @@ def main_export( torch_dtype=torch_dtype, ) + if task == "auto": + try: + task = TasksManager.infer_task_from_model(model_name_or_path) + except KeyError as e: + raise KeyError( + f"The task could not be automatically inferred. Please provide the argument --task with the task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}" + ) + if task != "stable-diffusion" and task + "-with-past" in TasksManager.get_supported_tasks_for_model_type( model.config.model_type.replace("_", "-"), "onnx" ): @@ -215,7 +216,7 @@ def main_export( needs_pad_token_id = ( isinstance(onnx_config, OnnxConfigWithPast) and getattr(model.config, "pad_token_id", None) is None - and task in ["sequence_classification"] + and task in ["text-classification"] ) if needs_pad_token_id: if pad_token_id is not None: @@ -265,7 +266,7 @@ def main_export( model.feature_extractor.save_pretrained(output.joinpath("feature_extractor")) model.save_config(output) else: - if model.config.is_encoder_decoder and task.startswith("causal-lm"): + if model.config.is_encoder_decoder and task.startswith("text-generation"): raise ValueError( f"model.config.is_encoder_decoder is True and task is `{task}`, which are incompatible. If the task was auto-inferred, please fill a bug report" f"at https://github.com/huggingface/optimum, if --task was explicitely passed, make sure you selected the right task for the model," @@ -275,11 +276,18 @@ def main_export( onnx_files_subpaths = None if ( model.config.is_encoder_decoder - and task.startswith(("seq2seq-lm", "speech2seq-lm", "vision2seq-lm", "default-with-past")) + and task.startswith( + ( + "text2text-generation", + "automatic-speech-recognition", + "image-to-text", + "feature-extraction-with-past", + ) + ) and not monolith ): models_and_onnx_configs = get_encoder_decoder_models_for_export(model, onnx_config) - elif task.startswith("causal-lm") and not monolith: + elif task.startswith("text-generation") and not monolith: models_and_onnx_configs = get_decoder_models_for_export(model, onnx_config) else: models_and_onnx_configs = {"model": (model, onnx_config)} diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py index 9749029b50..875f1bfbcb 100644 --- a/optimum/exporters/onnx/base.py +++ b/optimum/exporters/onnx/base.py @@ -109,7 +109,7 @@ class OnnxConfig(ExportConfig, ABC): Args: config (`transformers.PretrainedConfig`): The model configuration. - task (`str`, defaults to `"default"`): + task (`str`, defaults to `"feature-extraction"`): The task the model should be exported for. 
""" @@ -122,10 +122,10 @@ class OnnxConfig(ExportConfig, ABC): _TASK_TO_COMMON_OUTPUTS = { "audio-classification": OrderedDict({"logits": {0: "batch_size"}}), "audio-frame-classification": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}), - "audio-ctc": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}), + "automatic-speech-recognition": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}), "audio-xvector": OrderedDict({"logits": {0: "batch_size"}, "embeddings": {0: "batch_size"}}), - "causal-lm": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}), - "default": OrderedDict({"last_hidden_state": {0: "batch_size", 1: "sequence_length"}}), + "text-generation": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}), + "feature-extraction": OrderedDict({"last_hidden_state": {0: "batch_size", 1: "sequence_length"}}), "image-classification": OrderedDict({"logits": {0: "batch_size"}}), # TODO: Is this the same thing as semantic-segmentation? "image-segmentation": OrderedDict( @@ -136,7 +136,7 @@ class OnnxConfig(ExportConfig, ABC): } ), "masked-im": OrderedDict({"logits": {0: "batch_size"}}), - "masked-lm": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}), + "fill-mask": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}), "multiple-choice": OrderedDict({"logits": {0: "batch_size", 1: "num_choices"}}), "object-detection": OrderedDict( { @@ -151,11 +151,10 @@ class OnnxConfig(ExportConfig, ABC): } ), "semantic-segmentation": OrderedDict({"logits": {0: "batch_size", 1: "num_labels", 2: "height", 3: "width"}}), - "seq2seq-lm": OrderedDict({"logits": {0: "batch_size", 1: "decoder_sequence_length"}}), - "sequence-classification": OrderedDict({"logits": {0: "batch_size"}}), - "speech2seq-lm": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}), + "text2text-generation": OrderedDict({"logits": {0: "batch_size", 1: "decoder_sequence_length"}}), + "text-classification": OrderedDict({"logits": {0: "batch_size"}}), "token-classification": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}), - "vision2seq-lm": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}), + "image-to-text": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}), "zero-shot-image-classification": OrderedDict( { "logits_per_image": {0: "image_batch_size", 1: "text_batch_size"}, @@ -171,7 +170,7 @@ class OnnxConfig(ExportConfig, ABC): # }), } - def __init__(self, config: "PretrainedConfig", task: str = "default"): + def __init__(self, config: "PretrainedConfig", task: str = "feature-extraction"): if task not in self._TASK_TO_COMMON_OUTPUTS: raise ValueError( f"{task} is not a supported task, supported tasks: {', '.join(self._TASK_TO_COMMON_OUTPUTS.keys())}" @@ -460,7 +459,7 @@ class OnnxConfigWithPast(OnnxConfig, ABC): def __init__( self, config: "PretrainedConfig", - task: str = "default", + task: str = "feature-extraction", use_past: bool = False, use_past_in_inputs: Optional[bool] = None, use_present_in_outputs: Optional[bool] = None, @@ -489,14 +488,14 @@ def __init__( super().__init__(config, task=task) @classmethod - def with_past(cls, config: "PretrainedConfig", task: str = "default") -> "OnnxConfigWithPast": + def with_past(cls, config: "PretrainedConfig", task: str = "feature-extraction") -> "OnnxConfigWithPast": """ Instantiates a [`~optimum.exporters.onnx.OnnxConfig`] with `use_past` attribute set to `True`. 
Args: config (`transformers.PretrainedConfig`): The underlying model's config to use when exporting to ONNX. - task (`str`, defaults to `"default"`): + task (`str`, defaults to `"feature-extraction"`): The task the model should be exported for. Returns: @@ -509,7 +508,7 @@ def outputs(self) -> Dict[str, Dict[int, str]]: if self.use_past is False: common_outputs = super().outputs # In the other cases, the sequence_length axis is not dynamic, always of length 1 - elif self.task == "default": + elif self.task == "feature-extraction": common_outputs = OrderedDict({"last_hidden_state": {0: "batch_size"}}) else: common_outputs = OrderedDict({"logits": {0: "batch_size"}}) @@ -656,7 +655,7 @@ class OnnxSeq2SeqConfigWithPast(OnnxConfigWithPast): def __init__( self, config: "PretrainedConfig", - task: str = "default", + task: str = "feature-extraction", use_past: bool = False, use_past_in_inputs: Optional[bool] = None, use_present_in_outputs: Optional[bool] = None, @@ -675,7 +674,7 @@ def __init__( def override_attributes_for_behavior(self): """Override this to specify custom attribute change for a given behavior.""" if self._behavior is ConfigBehavior.ENCODER: - self.task = "default" + self.task = "feature-extraction" self.use_past_in_inputs = False self.use_present_in_outputs = False if self._behavior is ConfigBehavior.DECODER: @@ -842,13 +841,13 @@ class OnnxConfigWithLoss(OnnxConfig, ABC): """ _tasks_to_extra_inputs = { - "default": {"labels": {0: "batch_size"}}, - "masked-lm": {"labels": {0: "batch_size", 1: "sequence_length"}}, - "causal-lm": {"labels": {0: "batch_size", 1: "sequence_length"}}, - "causal-lm-with-past": {"labels": {0: "batch_size"}}, - "seq2seq-lm": {"labels": {0: "batch_size", 1: "sequence_length"}}, - "seq2seq-lm-with-past": {"labels": {0: "batch_size"}}, - "sequence-classification": {"labels": {0: "batch_size"}}, + "feature-extraction": {"labels": {0: "batch_size"}}, + "fill-mask": {"labels": {0: "batch_size", 1: "sequence_length"}}, + "text-generation": {"labels": {0: "batch_size", 1: "sequence_length"}}, + "text-generation-with-past": {"labels": {0: "batch_size"}}, + "text2text-generation": {"labels": {0: "batch_size", 1: "sequence_length"}}, + "text2text-generation-with-past": {"labels": {0: "batch_size"}}, + "text-classification": {"labels": {0: "batch_size"}}, "token-classification": {"labels": {0: "batch_size", 1: "sequence_length"}}, "multiple-choice": {"labels": {0: "batch_size"}}, "question-answering": { @@ -858,7 +857,7 @@ class OnnxConfigWithLoss(OnnxConfig, ABC): "image-classification": {"labels": {0: "batch_size"}}, } _tasks_to_extra_outputs = { - "default": OrderedDict({"loss": {}}), + "feature-extraction": OrderedDict({"loss": {}}), } DUMMY_EXTRA_INPUT_GENERATOR_CLASSES = (DummyLabelsGenerator,) @@ -882,7 +881,7 @@ def inputs(self) -> Dict[str, Dict[int, str]]: @property def outputs(self) -> Dict[str, Dict[int, str]]: common_outputs = self._onnx_config.outputs - extra_outputs = self._tasks_to_extra_outputs["default"] + extra_outputs = self._tasks_to_extra_outputs["feature-extraction"] common_outputs.update(extra_outputs) for key in reversed(extra_outputs.keys()): common_outputs.move_to_end(key, last=False) @@ -938,10 +937,10 @@ def flatten_seq2seq_past_key_values(self, flattened_output, name, idx, t): def flatten_output_collection_property(self, name: str, field: Iterable[Any]) -> Dict[str, Any]: flattened_output = {} if name in ["present", "past_key_values"]: - if "causal-lm" in self.task: + if self.task.startswith("text-generation"): for idx, t in enumerate(field): 
self.flatten_decoder_past_key_values(flattened_output, name, idx, t) - elif "seq2seq-lm" in self.task: + elif "text2text-generation" in self.task: for idx, t in enumerate(field): self.flatten_seq2seq_past_key_values(flattened_output, name, idx, t) else: diff --git a/optimum/exporters/onnx/config.py b/optimum/exporters/onnx/config.py index 68ee6dec72..e53e23f3c8 100644 --- a/optimum/exporters/onnx/config.py +++ b/optimum/exporters/onnx/config.py @@ -279,7 +279,7 @@ class EncoderDecoderOnnxConfig(OnnxSeq2SeqConfigWithPast): def __init__( self, config: "PretrainedConfig", - task: str = "default", + task: str = "feature-extraction", use_past: bool = False, use_past_in_inputs: Optional[bool] = None, use_present_in_outputs: Optional[bool] = None, @@ -300,14 +300,14 @@ def __init__( if self._behavior is not ConfigBehavior.DECODER: encoder_onnx_config_constructor = TasksManager.get_exporter_config_constructor( - exporter="onnx", task="default", model_type=config.encoder.model_type + exporter="onnx", task="feature-extraction", model_type=config.encoder.model_type ) self._encoder_onnx_config = encoder_onnx_config_constructor(config.encoder) self._normalized_config.ENCODER_NORMALIZED_CONFIG_CLASS = self._encoder_onnx_config._normalized_config if self._behavior is not ConfigBehavior.ENCODER: decoder_onnx_config_constructor = TasksManager.get_exporter_config_constructor( - exporter="onnx", task="default", model_type=config.decoder.model_type + exporter="onnx", task="feature-extraction", model_type=config.decoder.model_type ) kwargs = {} if issubclass(decoder_onnx_config_constructor.func, OnnxConfigWithPast): diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 943a43ee52..e335542c7a 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -349,7 +349,7 @@ def __init__( def generate(self, input_name: str, framework: str = "pt"): int_tensor = super().generate(input_name, framework=framework) # This inserts EOS_TOKEN_ID at random locations along the sequence length dimension. - if self.force_eos_token_id_presence and "input_ids" in input_name and self.task == "sequence-classification": + if self.force_eos_token_id_presence and "input_ids" in input_name and self.task == "text-classification": for idx in range(self.batch_size): if self.eos_token_id in int_tensor[idx]: continue @@ -363,7 +363,7 @@ class BartOnnxConfig(TextSeq2SeqOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig.with_args( encoder_num_layers="encoder_layers", decoder_num_layers="decoder_layers", - num_layers="decoder_layers", # Used for the causal-lm task past key values input generation. + num_layers="decoder_layers", # Used for the text-generation task past key values input generation. 
encoder_num_attention_heads="encoder_attention_heads", decoder_num_attention_heads="decoder_attention_heads", eos_token_id="eos_token_id", @@ -371,12 +371,12 @@ class BartOnnxConfig(TextSeq2SeqOnnxConfig): DUMMY_INPUT_GENERATOR_CLASSES = ( BartDummyTextInputGenerator, { - "default": DummySeq2SeqDecoderTextInputGenerator, - "causal-lm": DummyDecoderTextInputGenerator, + "feature-extraction": DummySeq2SeqDecoderTextInputGenerator, + "text-generation": DummyDecoderTextInputGenerator, }, { - "default": DummySeq2SeqPastKeyValuesGenerator, - "causal-lm": DummyPastKeyValuesGenerator, + "feature-extraction": DummySeq2SeqPastKeyValuesGenerator, + "text-generation": DummyPastKeyValuesGenerator, }, ) @@ -384,11 +384,11 @@ def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGen dummy_text_input_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[0]( self.task, self._normalized_config, **kwargs ) - task = "default" if self.task != "causal-lm" else "causal-lm" + task = "feature-extraction" if self.task != "text-generation" else "text-generation" dummy_decoder_text_input_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[1][task]( self.task, self._normalized_config, **kwargs ) - if self.task != "causal-lm": + if self.task != "text-generation": kwargs["encoder_sequence_length"] = dummy_text_input_generator.sequence_length dummy_seq2seq_past_key_values_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[2][task]( @@ -440,16 +440,16 @@ def inputs_for_other_tasks(self): @property def inputs(self) -> Dict[str, Dict[int, str]]: inputs_properties = { - "default": self.inputs_for_default_and_seq2seq_lm, - "seq2seq-lm": self.inputs_for_default_and_seq2seq_lm, - "causal-lm": self.inputs_for_causal_lm, + "feature-extraction": self.inputs_for_default_and_seq2seq_lm, + "text2text-generation": self.inputs_for_default_and_seq2seq_lm, + "text-generation": self.inputs_for_causal_lm, "other": self.inputs_for_other_tasks, } return inputs_properties.get(self.task, inputs_properties["other"]) @property def outputs(self) -> Dict[str, Dict[int, str]]: - if self.task in ["default", "seq2seq-lm"]: + if self.task in ["feature-extraction", "text2text-generation"]: common_outputs = super().outputs else: common_outputs = super(OnnxConfigWithPast, self).outputs @@ -463,8 +463,8 @@ def outputs(self) -> Dict[str, Dict[int, str]]: return common_outputs def generate_dummy_inputs(self, framework: str = "pt", **kwargs): - # This will handle the attention mask padding when Bart is used for causal-lm. - if self.task == "causal-lm": + # This will handle the attention mask padding when Bart is used for text-generation. 
+ if self.task == "text-generation": self.PAD_ATTENTION_MASK_TO_PAST = True dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs) @@ -474,7 +474,7 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs): return dummy_inputs def flatten_past_key_values(self, flattened_output, name, idx, t): - if self.task in ["default", "seq2seq-lm"]: + if self.task in ["feature-extraction", "text2text-generation"]: flattened_output = super().flatten_past_key_values(flattened_output, name, idx, t) else: flattened_output = super(OnnxSeq2SeqConfigWithPast, self).flatten_past_key_values( @@ -794,7 +794,7 @@ class LayoutLMv3OnnxConfig(TextAndVisionOnnxConfig): @property def inputs(self) -> Dict[str, Dict[int, str]]: - if self.task in ["sequence-classification", "question-answering"]: + if self.task in ["text-classification", "question-answering"]: pixel_values_dynamic_axes = {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"} else: pixel_values_dynamic_axes = {0: "batch_size", 1: "num_channels"} @@ -833,14 +833,14 @@ class PerceiverOnnxConfig(TextAndVisionOnnxConfig): PerceiverDummyInputGenerator, ) + TextAndVisionOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES - def __init__(self, config: "PretrainedConfig", task: str = "default"): + def __init__(self, config: "PretrainedConfig", task: str = "feature-extraction"): super().__init__(config, task=task) self.is_generating_dummy_inputs = False @property def inputs_name(self): if self.is_generating_dummy_inputs: - if self.task in ["masked-lm", "sequence-classification"]: + if self.task in ["fill-mask", "text-classification"]: return "input_ids" else: return "pixel_values" @@ -1046,7 +1046,7 @@ class VisionEncoderDecoderOnnxConfig(EncoderDecoderOnnxConfig): def __init__( self, config: "PretrainedConfig", - task: str = "default", + task: str = "feature-extraction", use_past: bool = False, use_past_in_inputs: Optional[bool] = None, use_present_in_outputs: Optional[bool] = None, diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py index ddf83acd35..03777fddf6 100644 --- a/optimum/exporters/onnx/utils.py +++ b/optimum/exporters/onnx/utils.py @@ -169,7 +169,7 @@ def get_stable_diffusion_models_for_export( # Text encoder text_encoder_config_constructor = TasksManager.get_exporter_config_constructor( - model=pipeline.text_encoder, exporter="onnx", task="default" + model=pipeline.text_encoder, exporter="onnx", task="feature-extraction" ) text_encoder_onnx_config = text_encoder_config_constructor(pipeline.text_encoder.config) models_for_export["text_encoder"] = (pipeline.text_encoder, text_encoder_onnx_config) diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 4922f80900..535aa78ac3 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -23,7 +23,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, Union import huggingface_hub -from transformers import PretrainedConfig, is_tf_available, is_torch_available +from transformers import AutoConfig, PretrainedConfig, is_tf_available, is_torch_available from transformers.utils import TF2_WEIGHTS_NAME, WEIGHTS_NAME, logging from ..utils.import_utils import is_onnx_available @@ -118,6 +118,21 @@ def supported_tasks_mapping( return mapping +def get_automodels_to_tasks(tasks_to_automodel: Dict[str, Union[str, Tuple[str]]]) -> Dict[str, str]: + """ + Reverses tasks_to_automodel while flattening the case where the same task maps to several + auto classes (e.g. automatic-speech-recognition). 
+ """ + automodels_to_tasks = {} + for task, automodels in tasks_to_automodel.items(): + if isinstance(automodels, str): + automodels_to_tasks[automodels] = task + else: + automodels_to_tasks.update({automodel_name: task for automodel_name in automodels}) + + return automodels_to_tasks + + class TasksManager: """ Handles the `task name -> model class` and `architecture -> configuration` mappings. @@ -126,12 +141,15 @@ class TasksManager: _TASKS_TO_AUTOMODELS = {} _TASKS_TO_TF_AUTOMODELS = {} if is_torch_available(): + # Refer to https://huggingface.co/datasets/huggingface/transformers-metadata/blob/main/pipeline_tags.json + # In case the same task (pipeline tag) may map to several loading classes, we use a tuple and the + # auto-class _model_mapping to determine the right one. _TASKS_TO_AUTOMODELS = { - "default": "AutoModel", - "masked-lm": "AutoModelForMaskedLM", - "causal-lm": "AutoModelForCausalLM", - "seq2seq-lm": "AutoModelForSeq2SeqLM", - "sequence-classification": "AutoModelForSequenceClassification", + "feature-extraction": "AutoModel", + "fill-mask": "AutoModelForMaskedLM", + "text-generation": "AutoModelForCausalLM", + "text2text-generation": "AutoModelForSeq2SeqLM", + "text-classification": "AutoModelForSequenceClassification", "token-classification": "AutoModelForTokenClassification", "multiple-choice": "AutoModelForMultipleChoice", "object-detection": "AutoModelForObjectDetection", @@ -140,24 +158,23 @@ class TasksManager: "image-segmentation": "AutoModelForImageSegmentation", "masked-im": "AutoModelForMaskedImageModeling", "semantic-segmentation": "AutoModelForSemanticSegmentation", - "speech2seq-lm": "AutoModelForSpeechSeq2Seq", + "automatic-speech-recognition": ("AutoModelForSpeechSeq2Seq", "AutoModelForCTC"), "audio-classification": "AutoModelForAudioClassification", "audio-frame-classification": "AutoModelForAudioFrameClassification", - "audio-ctc": "AutoModelForCTC", "audio-xvector": "AutoModelForAudioXVector", - "vision2seq-lm": "AutoModelForVision2Seq", + "image-to-text": "AutoModelForVision2Seq", "stable-diffusion": "StableDiffusionPipeline", "zero-shot-image-classification": "AutoModelForZeroShotImageClassification", "zero-shot-object-detection": "AutoModelForZeroShotObjectDetection", } if is_tf_available(): _TASKS_TO_TF_AUTOMODELS = { - "default": "TFAutoModel", - "masked-lm": "TFAutoModelForMaskedLM", - "causal-lm": "TFAutoModelForCausalLM", + "feature-extraction": "TFAutoModel", + "fill-mask": "TFAutoModelForMaskedLM", + "text-generation": "TFAutoModelForCausalLM", "image-classification": "TFAutoModelForImageClassification", - "seq2seq-lm": "TFAutoModelForSeq2SeqLM", - "sequence-classification": "TFAutoModelForSequenceClassification", + "text2text-generation": "TFAutoModelForSeq2SeqLM", + "text-classification": "TFAutoModelForSequenceClassification", "token-classification": "TFAutoModelForTokenClassification", "multiple-choice": "TFAutoModelForMultipleChoice", "object-detection": "TFAutoModelForObjectDetection", @@ -165,25 +182,45 @@ class TasksManager: "image-segmentation": "TFAutoModelForImageSegmentation", "masked-im": "TFAutoModelForMaskedImageModeling", "semantic-segmentation": "TFAutoModelForSemanticSegmentation", - "speech2seq-lm": "TFAutoModelForSpeechSeq2Seq", + "automatic-speech-recognition": "TFAutoModelForSpeechSeq2Seq", "audio-classification": "TFAutoModelForAudioClassification", "audio-frame-classification": "TFAutoModelForAudioFrameClassification", - "audio-ctc": "TFAutoModelForCTC", "audio-xvector": "TFAutoModelForAudioXVector", - "vision2seq-lm": 
"TFAutoModelForVision2Seq", + "image-to-text": "TFAutoModelForVision2Seq", "zero-shot-image-classification": "TFAutoModelForZeroShotImageClassification", "zero-shot-object-detection": "TFAutoModelForZeroShotObjectDetection", } - _AUTOMODELS_TO_TASKS = {cls_name: task for task, cls_name in _TASKS_TO_AUTOMODELS.items()} - _TF_AUTOMODELS_TO_TASKS = {cls_name: task for task, cls_name in _TASKS_TO_TF_AUTOMODELS.items()} + _SYNONYM_TASK_MAP = { + "sequence-classification": "text-classification", + "causal-lm": "text-generation", + "causal-lm-with-past": "text-generation-with-past", + "seq2seq-lm": "text2text-generation", + "seq2seq-lm-with-past": "text2text-generation-with-past", + "speech2seq-lm": "automatic-speech-recognition", + "speech2seq-lm-with-past": "automatic-speech-recognition-with-past", + "masked-lm": "fill-mask", + "vision2seq-lm": "image-to-text", + "default": "feature-extraction", + "default-with-past": "feature-extraction-with-past", + "audio-ctc": "automatic-speech-recognition", + } + + # Reverse dictionaries str -> str, where several automodels may map to the same task + _AUTOMODELS_TO_TASKS = get_automodels_to_tasks(_TASKS_TO_AUTOMODELS) + _TF_AUTOMODELS_TO_TASKS = get_automodels_to_tasks(_TASKS_TO_TF_AUTOMODELS) + + _CUSTOM_CLASSES = { + ("pt", "pix2struct", "image-to-text"): ("transformers", "Pix2StructForConditionalGeneration"), + ("pt", "visual-bert", "question-answering"): ("transformers", "VisualBertForQuestionAnswering"), + } _TASKS_TO_LIBRARY = { - "default": "transformers", - "masked-lm": "transformers", - "causal-lm": "transformers", - "seq2seq-lm": "transformers", - "sequence-classification": "transformers", + "feature-extraction": "transformers", + "fill-mask": "transformers", + "text-generation": "transformers", + "text2text-generation": "transformers", + "text-classification": "transformers", "token-classification": "transformers", "multiple-choice": "transformers", "object-detection": "transformers", @@ -192,29 +229,28 @@ class TasksManager: "image-segmentation": "transformers", "masked-im": "transformers", "semantic-segmentation": "transformers", - "speech2seq-lm": "transformers", - "audio-ctc": "transformers", + "automatic-speech-recognition": "transformers", "audio-classification": "transformers", "audio-frame-classification": "transformers", "audio-xvector": "transformers", - "vision2seq-lm": "transformers", + "image-to-text": "transformers", "stable-diffusion": "diffusers", "zero-shot-image-classification": "transformers", "zero-shot-object-detection": "transformers", } - # TODO: some models here support causal-lm export but are not supported in ORTModelForCausalLM + # TODO: some models here support text-generation export but are not supported in ORTModelForCausalLM # Set of model topologies we support associated to the tasks supported by each topology and the factory _SUPPORTED_MODEL_TYPE = { "audio-spectrogram-transformer": supported_tasks_mapping( - "default", + "feature-extraction", "audio-classification", onnx="ASTOnnxConfig", ), "albert": supported_tasks_mapping( - "default", - "masked-lm", - "sequence-classification", + "feature-extraction", + "fill-mask", + "text-classification", "multiple-choice", "token-classification", "question-answering", @@ -222,24 +258,24 @@ class TasksManager: tflite="AlbertTFLiteConfig", ), "bart": supported_tasks_mapping( - "default", - "default-with-past", - "causal-lm", - "causal-lm-with-past", - "seq2seq-lm", - "seq2seq-lm-with-past", - "sequence-classification", + "feature-extraction", + 
"feature-extraction-with-past", + "text-generation", + "text-generation-with-past", + "text2text-generation", + "text2text-generation-with-past", + "text-classification", "question-answering", onnx="BartOnnxConfig", ), # BEiT cannot be used with the masked image modeling autoclass, so this task is excluded here - "beit": supported_tasks_mapping("default", "image-classification", onnx="BeitOnnxConfig"), + "beit": supported_tasks_mapping("feature-extraction", "image-classification", onnx="BeitOnnxConfig"), "bert": supported_tasks_mapping( - "default", - "masked-lm", - # the logic for causal-lm is not supported for BERT - # "causal-lm", - "sequence-classification", + "feature-extraction", + "fill-mask", + # the logic for text-generation is not supported for BERT + # "text-generation", + "text-classification", "multiple-choice", "token-classification", "question-answering", @@ -248,11 +284,11 @@ class TasksManager: ), # For big-bird and bigbird-pegasus being unsupported, refer to model_configs.py # "big-bird": supported_tasks_mapping( - # "default", - # "masked-lm", - # # the logic for causal-lm is not supported for big-bird - # # "causal-lm", - # "sequence-classification", + # "feature-extraction", + # "fill-mask", + # # the logic for text-generation is not supported for big-bird + # # "text-generation", + # "text-classification", # "multiple-choice", # "token-classification", # "question-answering", @@ -261,49 +297,49 @@ class TasksManager: # # tflite="BigBirdTFLiteConfig", # ), # "bigbird-pegasus": supported_tasks_mapping( - # "default", - # "default-with-past", - # "causal-lm", - # "causal-lm-with-past", - # "seq2seq-lm", - # "seq2seq-lm-with-past", - # "sequence-classification", + # "feature-extraction", + # "feature-extraction-with-past", + # "text-generation", + # "text-generation-with-past", + # "text2text-generation", + # "text2text-generation-with-past", + # "text-classification", # "question-answering", # onnx="BigBirdPegasusOnnxConfig", # ), "blenderbot": supported_tasks_mapping( - "default", - "default-with-past", - "causal-lm", - "causal-lm-with-past", - "seq2seq-lm", - "seq2seq-lm-with-past", + "feature-extraction", + "feature-extraction-with-past", + "text-generation", + "text-generation-with-past", + "text2text-generation", + "text2text-generation-with-past", onnx="BlenderbotOnnxConfig", ), "blenderbot-small": supported_tasks_mapping( - "default", - "default-with-past", - "causal-lm", - "causal-lm-with-past", - "seq2seq-lm", - "seq2seq-lm-with-past", + "feature-extraction", + "feature-extraction-with-past", + "text-generation", + "text-generation-with-past", + "text2text-generation", + "text2text-generation-with-past", onnx="BlenderbotSmallOnnxConfig", ), "bloom": supported_tasks_mapping( - "default", - "default-with-past", - "causal-lm", - "causal-lm-with-past", - "sequence-classification", + "feature-extraction", + "feature-extraction-with-past", + "text-generation", + "text-generation-with-past", + "text-classification", "token-classification", onnx="BloomOnnxConfig", ), "camembert": supported_tasks_mapping( - "default", - "masked-lm", - # the logic for causal-lm is not supported for camembert - # "causal-lm", - "sequence-classification", + "feature-extraction", + "fill-mask", + # the logic for text-generation is not supported for camembert + # "text-generation", + "text-classification", "multiple-choice", "token-classification", "question-answering", @@ -311,25 +347,25 @@ class TasksManager: tflite="CamembertTFLiteConfig", ), "clip": supported_tasks_mapping( - "default", + 
"feature-extraction", "zero-shot-image-classification", onnx="CLIPOnnxConfig", ), "clip-text-model": supported_tasks_mapping( - "default", + "feature-extraction", onnx="CLIPTextOnnxConfig", ), "codegen": supported_tasks_mapping( - "default", - "default-with-past", - "causal-lm", - "causal-lm-with-past", + "feature-extraction", + "feature-extraction-with-past", + "text-generation", + "text-generation-with-past", onnx="CodeGenOnnxConfig", ), "convbert": supported_tasks_mapping( - "default", - "masked-lm", - "sequence-classification", + "feature-extraction", + "fill-mask", + "text-classification", "multiple-choice", "token-classification", "question-answering", @@ -337,64 +373,66 @@ class TasksManager: tflite="ConvBertTFLiteConfig", ), "convnext": supported_tasks_mapping( - "default", + "feature-extraction", "image-classification", onnx="ConvNextOnnxConfig", ), "data2vec-text": supported_tasks_mapping( - "default", - "masked-lm", - "sequence-classification", + "feature-extraction", + "fill-mask", + "text-classification", "multiple-choice", "token-classification", "question-answering", onnx="Data2VecTextOnnxConfig", ), "data2vec-vision": supported_tasks_mapping( - "default", + "feature-extraction", "image-classification", # ONNX doesn't support `adaptive_avg_pool2d` yet # "semantic-segmentation", onnx="Data2VecVisionOnnxConfig", ), "data2vec-audio": supported_tasks_mapping( - "default", - "audio-ctc", + "feature-extraction", + "automatic-speech-recognition", "audio-classification", "audio-frame-classification", "audio-xvector", onnx="Data2VecAudioOnnxConfig", ), "deberta": supported_tasks_mapping( - "default", - "masked-lm", - "sequence-classification", + "feature-extraction", + "fill-mask", + "text-classification", "token-classification", "question-answering", onnx="DebertaOnnxConfig", tflite="DebertaTFLiteConfig", ), "deberta-v2": supported_tasks_mapping( - "default", - "masked-lm", - "sequence-classification", + "feature-extraction", + "fill-mask", + "text-classification", ("multiple-choice", ("onnx",)), "token-classification", "question-answering", onnx="DebertaV2OnnxConfig", tflite="DebertaV2TFLiteConfig", ), - "deit": supported_tasks_mapping("default", "image-classification", "masked-im", onnx="DeiTOnnxConfig"), + "deit": supported_tasks_mapping( + "feature-extraction", "image-classification", "masked-im", onnx="DeiTOnnxConfig" + ), "detr": supported_tasks_mapping( - "default", + "feature-extraction", "object-detection", "image-segmentation", onnx="DetrOnnxConfig", ), "distilbert": supported_tasks_mapping( - "default", - "masked-lm", - "sequence-classification", + "feature-extraction", + "fill-mask", + "text-classification", "multiple-choice", "token-classification", "question-answering", @@ -402,15 +440,15 @@ class TasksManager: tflite="DistilBertTFLiteConfig", ), "donut-swin": supported_tasks_mapping( - "default", + "feature-extraction", onnx="DonutSwinOnnxConfig", ), "electra": supported_tasks_mapping( - "default", - "masked-lm", - # the logic for causal-lm is not supported for electra - # "causal-lm", - "sequence-classification", + "feature-extraction", + "fill-mask", + # the logic for text-generation is not supported for electra + # "text-generation", + "text-classification", "multiple-choice", "token-classification", "question-answering", @@ -418,9 +456,9 @@ class TasksManager: tflite="ElectraTFLiteConfig", ), "flaubert": supported_tasks_mapping( - "default", - "masked-lm", - "sequence-classification", + "feature-extraction", + "fill-mask", + "text-classification", 
"multiple-choice", "token-classification", "question-answering", @@ -428,130 +466,130 @@ class TasksManager: tflite="FlaubertTFLiteConfig", ), "gpt2": supported_tasks_mapping( - "default", - "default-with-past", - "causal-lm", - "causal-lm-with-past", - "sequence-classification", + "feature-extraction", + "feature-extraction-with-past", + "text-generation", + "text-generation-with-past", + "text-classification", "token-classification", onnx="GPT2OnnxConfig", ), "gptj": supported_tasks_mapping( - "default", - "default-with-past", - "causal-lm", - "causal-lm-with-past", + "feature-extraction", + "feature-extraction-with-past", + "text-generation", + "text-generation-with-past", "question-answering", - "sequence-classification", + "text-classification", onnx="GPTJOnnxConfig", ), "gpt-neo": supported_tasks_mapping( - "default", - "default-with-past", - "causal-lm", - "causal-lm-with-past", - "sequence-classification", + "feature-extraction", + "feature-extraction-with-past", + "text-generation", + "text-generation-with-past", + "text-classification", onnx="GPTNeoOnnxConfig", ), "gpt-neox": supported_tasks_mapping( - "default", - "default-with-past", - "causal-lm", - "causal-lm-with-past", + "feature-extraction", + "feature-extraction-with-past", + "text-generation", + "text-generation-with-past", onnx="GPTNeoXOnnxConfig", ), "groupvit": supported_tasks_mapping( - "default", + "feature-extraction", onnx="GroupViTOnnxConfig", ), "hubert": supported_tasks_mapping( - "default", - "audio-ctc", + "feature-extraction", + "automatic-speech-recognition", "audio-classification", onnx="HubertOnnxConfig", ), "ibert": supported_tasks_mapping( - "default", - "masked-lm", - "sequence-classification", + "feature-extraction", + "fill-mask", + "text-classification", "multiple-choice", "token-classification", "question-answering", onnx="IBertOnnxConfig", ), "imagegpt": supported_tasks_mapping( - "default", + "feature-extraction", "image-classification", onnx="ImageGPTOnnxConfig", ), "layoutlm": supported_tasks_mapping( - "default", - "masked-lm", - "sequence-classification", + "feature-extraction", + "fill-mask", + "text-classification", "token-classification", onnx="LayoutLMOnnxConfig", ), # "layoutlmv2": supported_tasks_mapping( - # "default", + # "feature-extraction", # "question-answering", - # "sequence-classification", + # "text-classification", # "token-classification", # onnx="LayoutLMv2OnnxConfig", # ), "layoutlmv3": supported_tasks_mapping( - "default", + "feature-extraction", "question-answering", - "sequence-classification", + "text-classification", "token-classification", onnx="LayoutLMv3OnnxConfig", ), - "levit": supported_tasks_mapping("default", "image-classification", onnx="LevitOnnxConfig"), + "levit": supported_tasks_mapping("feature-extraction", "image-classification", onnx="LevitOnnxConfig"), "longt5": supported_tasks_mapping( - "default", - "default-with-past", - "seq2seq-lm", - "seq2seq-lm-with-past", + "feature-extraction", + "feature-extraction-with-past", + "text2text-generation", + "text2text-generation-with-past", onnx="LongT5OnnxConfig", ), # "longformer": supported_tasks_mapping( - # "default", - # "masked-lm", + # "feature-extraction", + # "fill-mask", # "multiple-choice", # "question-answering", - # "sequence-classification", + # "text-classification", # "token-classification", # onnx_config_cls="models.longformer.LongformerOnnxConfig", # ), "marian": supported_tasks_mapping( - "default", - "default-with-past", - "seq2seq-lm", - "seq2seq-lm-with-past", - "causal-lm", - 
"causal-lm-with-past", + "feature-extraction", + "feature-extraction-with-past", + "text2text-generation", + "text2text-generation-with-past", + "text-generation", + "text-generation-with-past", onnx="MarianOnnxConfig", ), "mbart": supported_tasks_mapping( - "default", - "default-with-past", - "causal-lm", - "causal-lm-with-past", - "seq2seq-lm", - "seq2seq-lm-with-past", - "sequence-classification", + "feature-extraction", + "feature-extraction-with-past", + "text-generation", + "text-generation-with-past", + "text2text-generation", + "text2text-generation-with-past", + "text-classification", "question-answering", onnx="MBartOnnxConfig", ), # TODO: enable once the missing operator is supported. # "mctct": supported_tasks_mapping( - # "default", - # "audio-ctc", + # "feature-extraction", + # "automatic-speech-recognition", # onnx="MCTCTOnnxConfig", # ), "mobilebert": supported_tasks_mapping( - "default", - "masked-lm", - "sequence-classification", + "feature-extraction", + "fill-mask", + "text-classification", "multiple-choice", "token-classification", "question-answering", @@ -559,24 +597,24 @@ class TasksManager: tflite="MobileBertTFLiteConfig", ), "mobilevit": supported_tasks_mapping( - "default", + "feature-extraction", "image-classification", onnx="MobileViTOnnxConfig", ), "mobilenet-v1": supported_tasks_mapping( - "default", + "feature-extraction", "image-classification", onnx="MobileNetV1OnnxConfig", ), "mobilenet-v2": supported_tasks_mapping( - "default", + "feature-extraction", "image-classification", onnx="MobileNetV2OnnxConfig", ), "mpnet": supported_tasks_mapping( - "default", - "masked-lm", - "sequence-classification", + "feature-extraction", + "fill-mask", + "text-classification", "multiple-choice", "token-classification", "question-answering", @@ -584,77 +622,77 @@ class TasksManager: tflite="MPNetTFLiteConfig", ), "mt5": supported_tasks_mapping( - "default", - "default-with-past", - "seq2seq-lm", - "seq2seq-lm-with-past", + "feature-extraction", + "feature-extraction-with-past", + "text2text-generation", + "text2text-generation-with-past", onnx="MT5OnnxConfig", ), "m2m-100": supported_tasks_mapping( - "default", - "default-with-past", - "seq2seq-lm", - "seq2seq-lm-with-past", + "feature-extraction", + "feature-extraction-with-past", + "text2text-generation", + "text2text-generation-with-past", onnx="M2M100OnnxConfig", ), "nystromformer": supported_tasks_mapping( - "default", - "masked-lm", + "feature-extraction", + "fill-mask", "multiple-choice", "question-answering", - "sequence-classification", + "text-classification", "token-classification", onnx="NystromformerOnnxConfig", ), # TODO: owlvit cannot be exported yet, check model_config.py to know why. 
# "owlvit": supported_tasks_mapping( - # "default", + # "feature-extraction", # "zero-shot-object-detection", # onnx="OwlViTOnnxConfig", # ), "opt": supported_tasks_mapping( - "default", - "default-with-past", - "causal-lm", - "causal-lm-with-past", + "feature-extraction", + "feature-extraction-with-past", + "text-generation", + "text-generation-with-past", "question-answering", - "sequence-classification", + "text-classification", onnx="OPTOnnxConfig", ), "pegasus": supported_tasks_mapping( - "default", - "default-with-past", - "causal-lm", - "causal-lm-with-past", - "seq2seq-lm", - "seq2seq-lm-with-past", + "feature-extraction", + "feature-extraction-with-past", + "text-generation", + "text-generation-with-past", + "text2text-generation", + "text2text-generation-with-past", onnx="PegasusOnnxConfig", ), "perceiver": supported_tasks_mapping( - "masked-lm", + "fill-mask", "image-classification", - "sequence-classification", + "text-classification", onnx="PerceiverOnnxConfig", ), "poolformer": supported_tasks_mapping( - "default", + "feature-extraction", "image-classification", onnx="PoolFormerOnnxConfig", ), "regnet": supported_tasks_mapping( - "default", + "feature-extraction", "image-classification", onnx="RegNetOnnxConfig", ), "resnet": supported_tasks_mapping( - "default", "image-classification", onnx="ResNetOnnxConfig", tflite="ResNetTFLiteConfig" + "feature-extraction", "image-classification", onnx="ResNetOnnxConfig", tflite="ResNetTFLiteConfig" ), "roberta": supported_tasks_mapping( - "default", - "masked-lm", - # the logic for causal-lm is not supported for roberta - # "causal-lm", - "sequence-classification", + "feature-extraction", + "fill-mask", + # the logic for text-generation is not supported for roberta + # "text-generation", + "text-classification", "multiple-choice", "token-classification", "question-answering", @@ -662,11 +700,11 @@ class TasksManager: tflite="RobertaTFLiteConfig", ), "roformer": supported_tasks_mapping( - "default", - "masked-lm", - # the logic for causal-lm is not supported for roformer - # "causal-lm", - "sequence-classification", + "feature-extraction", + "fill-mask", + # the logic for text-generation is not supported for roformer + # "text-generation", + "text-classification", "token-classification", "multiple-choice", "question-answering", @@ -675,62 +713,62 @@ class TasksManager: tflite="RoFormerTFLiteConfig", ), "segformer": supported_tasks_mapping( - "default", + "feature-extraction", "image-classification", "semantic-segmentation", onnx="SegformerOnnxConfig", ), "sew": supported_tasks_mapping( - "default", - "audio-ctc", + "feature-extraction", + "automatic-speech-recognition", "audio-classification", onnx="SEWOnnxConfig", ), "sew-d": supported_tasks_mapping( - "default", - "audio-ctc", + "feature-extraction", + "automatic-speech-recognition", "audio-classification", onnx="SEWDOnnxConfig", ), "speech-to-text": supported_tasks_mapping( - "default", - "default-with-past", - "speech2seq-lm", - "speech2seq-lm-with-past", + "feature-extraction", + "feature-extraction-with-past", + "automatic-speech-recognition", + "automatic-speech-recognition-with-past", onnx="Speech2TextOnnxConfig", ), "splinter": supported_tasks_mapping( - "default", + "feature-extraction", "question-answering", onnx="SplinterOnnxConfig", ), "squeezebert": supported_tasks_mapping( - "default", - "masked-lm", - "sequence-classification", + "feature-extraction", + "fill-mask", + "text-classification", "multiple-choice", "token-classification", "question-answering", 
onnx="SqueezeBertOnnxConfig", ), "swin": supported_tasks_mapping( - "default", + "feature-extraction", "image-classification", "masked-im", onnx="SwinOnnxConfig", ), "t5": supported_tasks_mapping( - "default", - "default-with-past", - "seq2seq-lm", - "seq2seq-lm-with-past", + "feature-extraction", + "feature-extraction-with-past", + "text2text-generation", + "text2text-generation-with-past", onnx="T5OnnxConfig", ), "trocr": supported_tasks_mapping( - "default", - "default-with-past", - "vision2seq-lm", - "vision2seq-lm-with-past", + "feature-extraction", + "feature-extraction-with-past", + "image-to-text", + "image-to-text-with-past", onnx="TrOCROnnxConfig", ), "unet": supported_tasks_mapping( @@ -738,14 +776,14 @@ class TasksManager: onnx="UNetOnnxConfig", ), "unispeech": supported_tasks_mapping( - "default", - "audio-ctc", + "feature-extraction", + "automatic-speech-recognition", "audio-classification", onnx="UniSpeechOnnxConfig", ), "unispeech-sat": supported_tasks_mapping( - "default", - "audio-ctc", + "feature-extraction", + "automatic-speech-recognition", "audio-classification", "audio-frame-classification", "audio-xvector", @@ -760,48 +798,50 @@ class TasksManager: onnx="VaeDecoderOnnxConfig", ), "vision-encoder-decoder": supported_tasks_mapping( - "vision2seq-lm", - "vision2seq-lm-with-past", + "image-to-text", + "image-to-text-with-past", onnx="VisionEncoderDecoderOnnxConfig", ), - "vit": supported_tasks_mapping("default", "image-classification", "masked-im", onnx="ViTOnnxConfig"), + "vit": supported_tasks_mapping( + "feature-extraction", "image-classification", "masked-im", onnx="ViTOnnxConfig" + ), "wavlm": supported_tasks_mapping( - "default", - "audio-ctc", + "feature-extraction", + "automatic-speech-recognition", "audio-classification", "audio-frame-classification", "audio-xvector", onnx="WavLMOnnxConfig", ), "wav2vec2": supported_tasks_mapping( - "default", - "audio-ctc", + "feature-extraction", + "automatic-speech-recognition", "audio-classification", "audio-frame-classification", "audio-xvector", onnx="Wav2Vec2OnnxConfig", ), "wav2vec2-conformer": supported_tasks_mapping( - "default", - "audio-ctc", + "feature-extraction", + "automatic-speech-recognition", "audio-classification", "audio-frame-classification", "audio-xvector", onnx="Wav2Vec2ConformerOnnxConfig", ), "whisper": supported_tasks_mapping( - "default", - "default-with-past", - "speech2seq-lm", - "speech2seq-lm-with-past", + "feature-extraction", + "feature-extraction-with-past", + "automatic-speech-recognition", + "automatic-speech-recognition-with-past", onnx="WhisperOnnxConfig", ), "xlm": supported_tasks_mapping( - "default", - "masked-lm", - # the logic for causal-lm is not supported for xlm - # "causal-lm", - "sequence-classification", + "feature-extraction", + "fill-mask", + # the logic for text-generation is not supported for xlm + # "text-generation", + "text-classification", "multiple-choice", "token-classification", "question-answering", @@ -809,11 +849,11 @@ class TasksManager: tflite="XLMTFLiteConfig", ), "xlm-roberta": supported_tasks_mapping( - "default", - "masked-lm", - # the logic for causal-lm is not supported for xlm-roberta - # "causal-lm", - "sequence-classification", + "feature-extraction", + "fill-mask", + # the logic for text-generation is not supported for xlm-roberta + # "text-generation", + "text-classification", "multiple-choice", "token-classification", "question-answering", @@ -821,7 +861,7 @@ class TasksManager: tflite="XLMRobertaTFLiteConfig", ), "yolos": supported_tasks_mapping( - 
"default", + "feature-extraction", "object-detection", onnx="YolosOnnxConfig", ), @@ -850,7 +890,7 @@ def create_register( ```python >>> register_for_new_backend = create_register("new-backend") - >>> @register_for_new_backend("bert", "sequence-classification", "token-classification") + >>> @register_for_new_backend("bert", "text-classification", "token-classification") >>> class BertNewBackendConfig(NewBackendConfig): >>> pass ``` @@ -925,8 +965,10 @@ def get_supported_model_type_for_task(task: str, exporter: str) -> List[str]: ] @staticmethod - def format_task(task: str) -> str: - return task.replace("-with-past", "") + def map_from_synonym(task: str) -> str: + if task in TasksManager._SYNONYM_TASK_MAP: + task = TasksManager._SYNONYM_TASK_MAP[task] + return task @staticmethod def _validate_framework_choice(framework: str): @@ -942,33 +984,88 @@ def _validate_framework_choice(framework: str): raise RuntimeError("Cannot export model using TensorFlow because no TensorFlow package was found.") @staticmethod - def get_model_class_for_task(task: str, framework: str = "pt") -> Type: + def get_model_class_for_task( + task: str, framework: str = "pt", model_type: Optional[str] = None, model_class_name: Optional[str] = None + ) -> Type: """ Attempts to retrieve an AutoModel class from a task name. Args: task (`str`): The task required. - framework (`str`, *optional*, defaults to `"pt"`): + framework (`str`, defaults to `"pt"`): The framework to use for the export. + model_type (`Optional[str]`, defaults to `None`): + The model type to retrieve the model class for. Some architectures need a custom class to be loaded, + and can not be loaded from auto class. + model_class_name (`Optional[str]`, defaults to `None`): + A model class name, allowing to override the default class that would be detected for the task. This + parameter is useful for example for "automatic-speech-recognition", that may map to + AutoModelForSpeechSeq2Seq or to AutoModelForCTC. Returns: The AutoModel class corresponding to the task. """ - task = TasksManager.format_task(task) + task = task.replace("-with-past", "") + task = TasksManager.map_from_synonym(task) + TasksManager._validate_framework_choice(framework) - if framework == "pt": - tasks_to_automodel = TasksManager._TASKS_TO_AUTOMODELS + + if (framework, model_type, task) in TasksManager._CUSTOM_CLASSES: + library, class_name = TasksManager._CUSTOM_CLASSES[(framework, model_type, task)] + loaded_library = importlib.import_module(library) + + return getattr(loaded_library, class_name) else: - tasks_to_automodel = TasksManager._TASKS_TO_TF_AUTOMODELS - if task not in tasks_to_automodel: - raise KeyError( - f"Unknown task: {task}. Possible values are: " - + ", ".join([f"`{key}` for {tasks_to_automodel[key]}" for key in tasks_to_automodel]) - ) + if framework == "pt": + tasks_to_automodel = TasksManager._TASKS_TO_AUTOMODELS + else: + tasks_to_automodel = TasksManager._TASKS_TO_TF_AUTOMODELS + + if task not in tasks_to_automodel: + raise KeyError( + f"Unknown task: {task}. 
@@ -1118,11 +1215,11 @@ def _infer_task_from_model_name_or_path(
             if "stable-diffusion" in model_info.tags:
                 inferred_task_name = "stable-diffusion"
             else:
-                transformers_info = model_info.transformersInfo
-                if model_info.config["model_type"] == "vision-encoder-decoder":
-                    inferred_task_name = "vision2seq-lm"
-                    # TODO: handle other possible special cases here.
+                if getattr(model_info, "pipeline_tag", None) is not None:
+                    inferred_task_name = model_info.pipeline_tag
                 else:
+                    transformers_info = model_info.transformersInfo
+                    if transformers_info is None or transformers_info.get("auto_model") is None:
                         raise RuntimeError(f"Could not infer the task from the model repo {model_name_or_path}")
                     auto_model_class_name = transformers_info["auto_model"]
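The rewritten branch above prefers the Hub's `pipeline_tag` over `transformersInfo` when inferring the task. A small illustration of where that tag comes from — a sketch, assuming network access to the Hub:

```python
from huggingface_hub import HfApi

info = HfApi().model_info("openai/whisper-tiny.en")
# The exporter now uses this tag directly whenever it is set on the repo.
print(info.pipeline_tag)  # e.g. "automatic-speech-recognition"
```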
@@ -1229,9 +1326,27 @@ def get_model_from_task(
         """
         framework = TasksManager.determine_framework(model_name_or_path, subfolder=subfolder, framework=framework)
+
+        original_task = task
         if task == "auto":
             task = TasksManager.infer_task_from_model(model_name_or_path, subfolder=subfolder, revision=revision)
-        model_class = TasksManager.get_model_class_for_task(task, framework)
+
+        model_type = None
+        model_class_name = None
+        if TasksManager._TASKS_TO_LIBRARY[task.replace("-with-past", "")] == "transformers":
+            # TODO: if automatic-speech-recognition is passed as task, it may map to several
+            # different auto classes (AutoModelForSpeechSeq2Seq or AutoModelForCTC),
+            # depending on the model type
+            if original_task in ["auto", "automatic-speech-recognition"]:
+                config = AutoConfig.from_pretrained(model_name_or_path)
+                model_type = config.model_type.replace("_", "-")
+                if original_task == "auto" and config.architectures is not None:
+                    model_class_name = config.architectures[0]
+
+        model_class = TasksManager.get_model_class_for_task(
+            task, framework, model_type=model_type, model_class_name=model_class_name
+        )
+
         kwargs = {"subfolder": subfolder, "revision": revision, "cache_dir": cache_dir, **model_kwargs}
         try:
             if framework == "pt":
@@ -1252,7 +1367,7 @@ def get_model_from_task(
     def get_exporter_config_constructor(
         exporter: str,
         model: Optional[Union["PreTrainedModel", "TFPreTrainedModel"]] = None,
-        task: str = "default",
+        task: str = "feature-extraction",
         model_type: Optional[str] = None,
         model_name: Optional[str] = None,
         exporter_config_kwargs: Optional[Dict[str, Any]] = None,
@@ -1265,7 +1380,7 @@ def get_exporter_config_constructor(
                 The exporter to use.
             model (`Optional[Union[PreTrainedModel, TFPreTrainedModel]]`, defaults to `None`):
                 The instance of the model.
-            task (`str`, defaults to `"default"`):
+            task (`str`, defaults to `"feature-extraction"`):
                 The task to retrieve the config for.
             model_type (`Optional[str]`, defaults to `None`):
                 The model type to retrieve the config for.
@@ -1290,6 +1405,8 @@ def get_exporter_config_constructor(
             model_name = getattr(model, "name", model_name)

         model_tasks = TasksManager.get_supported_tasks_for_model_type(model_type, exporter, model_name=model_name)
+
+        task = TasksManager.map_from_synonym(task)

         if task not in model_tasks:
             raise ValueError(
                 f"{model_type} doesn't support task {task} for the {exporter} backend."
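Because `map_from_synonym()` is now applied inside `get_exporter_config_constructor` as well, the pre-rename spelling should keep resolving to the right config. A minimal sketch, mirroring the lookup pattern used in the tests below:

```python
from transformers import BertConfig
from optimum.exporters import TasksManager

ctor = TasksManager.get_exporter_config_constructor(
    "onnx", model_type="bert", task="sequence-classification"  # legacy name
)
onnx_config = ctor(BertConfig())
print(type(onnx_config).__name__)  # expected: BertOnnxConfig
```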
diff --git a/optimum/exporters/tflite/base.py b/optimum/exporters/tflite/base.py
index 934b444d78..01f93f1289 100644
--- a/optimum/exporters/tflite/base.py
+++ b/optimum/exporters/tflite/base.py
@@ -83,10 +83,10 @@ class TFLiteQuantizationConfig:
             smallest split will be used.
         primary_key (`Optional[str]`, defaults `None`):
             The name of the column in the dataset containing the main data to preprocess. Only for
-            sequence-classification and token-classification.
+            text-classification and token-classification.
         secondary_key (`Optional[str]`, defaults `None`):
             The name of the second column in the dataset containing the main data to preprocess, not always needed.
-            Only for sequence-classification and token-classification.
+            Only for text-classification and token-classification.
         question_key (`Optional[str]`, defaults `None`):
             The name of the column containing the question in the dataset. Only for question-answering.
         context_key (`Optional[str]`, defaults `None`):
@@ -141,7 +141,7 @@ class TFLiteConfig(ExportConfig, ABC):
     Args:
         config (`transformers.PretrainedConfig`):
             The model configuration.
-        task (`str`, defaults to `"default"`):
+        task (`str`, defaults to `"feature-extraction"`):
             The task the model should be exported for.

     The rest of the arguments are used to specify the shape of the inputs the model can take.
@@ -157,23 +157,22 @@ class TFLiteConfig(ExportConfig, ABC):
     ] = tuple(approach for approach in QuantizationApproach)

     _TASK_TO_COMMON_OUTPUTS = {
-        "causal-lm": ["logits"],
-        "default": ["last_hidden_state"],
+        "text-generation": ["logits"],
+        "feature-extraction": ["last_hidden_state"],
         "image-classification": ["logits"],
         "image-segmentation": ["logits", "pred_boxes", "pred_masks"],
         "masked-im": ["logits"],
-        "masked-lm": ["logits"],
+        "fill-mask": ["logits"],
         "multiple-choice": ["logits"],
         "object-detection": ["logits", "pred_boxes"],
         "question-answering": ["start_logits", "end_logits"],
         "semantic-segmentation": ["logits"],
-        "seq2seq-lm": ["logits", "encoder_last_hidden_state"],
-        "sequence-classification": ["logits"],
+        "text2text-generation": ["logits", "encoder_last_hidden_state"],
+        "text-classification": ["logits"],
         "token-classification": ["logits"],
-        "speech2seq-lm": ["logits"],
+        "automatic-speech-recognition": ["logits"],
         "audio-classification": ["logits"],
         "audio-frame-classification": ["logits"],
-        "audio-ctc": ["logits"],
         "audio-xvector": ["logits"],
     }
diff --git a/optimum/onnxruntime/modeling_ort.py b/optimum/onnxruntime/modeling_ort.py
index c7b15eff54..76e1768113 100644
--- a/optimum/onnxruntime/modeling_ort.py
+++ b/optimum/onnxruntime/modeling_ort.py
@@ -981,7 +981,7 @@ class ORTModelForMaskedLM(ORTModel):
         + MASKED_LM_EXAMPLE.format(
             processor_class=_TOKENIZER_FOR_DOC,
             model_class="ORTModelForMaskedLM",
-            checkpoint="optimum/bert-base-uncased-for-masked-lm",
+            checkpoint="optimum/bert-base-uncased-for-fill-mask",
         )
     )
     def forward(
diff --git a/optimum/onnxruntime/runs/utils.py b/optimum/onnxruntime/runs/utils.py
index 84d003302e..6d31cf6570 100644
--- a/optimum/onnxruntime/runs/utils.py
+++ b/optimum/onnxruntime/runs/utils.py
@@ -9,7 +9,7 @@
 task_ortmodel_map = {
-    "causal-lm": ORTModelForCausalLM,
+    "text-generation": ORTModelForCausalLM,
     "feature-extraction": ORTModelForFeatureExtraction,
     "image-classification": ORTModelForImageClassification,
     "question-answering": ORTModelForQuestionAnswering,
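The run utilities now key ORT model classes by the new task names. A one-line check, assuming the mapping above is importable as shown in the patch:

```python
from optimum.onnxruntime.runs.utils import task_ortmodel_map

print(task_ortmodel_map["text-generation"])  # ORTModelForCausalLM
```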
"text-classification": ORTModelForSequenceClassification, "token-classification": ORTModelForTokenClassification, "multiple-choice": ORTModelForMultipleChoice, "question-answering": ORTModelForQuestionAnswering, "image-classification": ORTModelForImageClassification, "semantic-segmentation": ORTModelForSemanticSegmentation, - "speech2seq-lm": ORTModelForSpeechSeq2Seq, + "automatic-speech-recognition": ORTModelForSpeechSeq2Seq, } SUPPORTED_FEATURES = _TASKS_TO_ORTMODELS.keys() @@ -289,7 +289,7 @@ def __init__( self, model: Union[PreTrainedModel, nn.Module] = None, tokenizer: Optional[PreTrainedTokenizerBase] = None, - feature: str = "default", + feature: str = "feature-extraction", args: ORTTrainingArguments = None, data_collator: Optional[DataCollator] = None, train_dataset: Optional[Dataset] = None, @@ -1087,10 +1087,10 @@ def evaluation_loop_ort( # Load ORT model support_loss_in_modeling = self.feature in [ - "causal-lm", - "causal-lm-with-past", - "seq2seq-lm", - "seq2seq-lm-with-past", + "text-generation", + "text-generation-with-past", + "text2text-generation", + "text2text-generation-with-past", ] support_feature = self.feature in ORTFeaturesManager.SUPPORTED_FEATURES if support_loss_in_modeling or (not self.exported_with_loss and support_feature): @@ -1318,10 +1318,10 @@ def prediction_loop_ort( # Load ORT model support_loss_in_modeling = self.feature in [ - "causal-lm", - "causal-lm-with-past", - "seq2seq-lm", - "seq2seq-lm-with-past", + "text-generation", + "text-generation-with-past", + "text2text-generation", + "text2text-generation-with-past", ] support_feature = self.feature in ORTFeaturesManager.SUPPORTED_FEATURES if support_loss_in_modeling or (not self.exported_with_loss and support_feature): @@ -1547,7 +1547,7 @@ def compute_loss_ort(self, model, inputs, return_outputs=False): self._past = outputs[self.args.past_index] if labels is not None: - if "causal-lm" in self.feature: + if "text-generation" in self.feature: loss = self.label_smoother(outputs, labels, shift_labels=True) else: loss = self.label_smoother(outputs, labels) diff --git a/optimum/utils/preprocessing/task_processors_manager.py b/optimum/utils/preprocessing/task_processors_manager.py index 4a19713b6a..2720ed41fb 100644 --- a/optimum/utils/preprocessing/task_processors_manager.py +++ b/optimum/utils/preprocessing/task_processors_manager.py @@ -28,7 +28,7 @@ class TaskProcessorsManager: _TASK_TO_DATASET_PROCESSING_CLASS = { - "sequence-classification": TextClassificationProcessing, + "text-classification": TextClassificationProcessing, "token-classification": TokenClassificationProcessing, "question-answering": QuestionAnsweringProcessing, "image-classification": ImageClassificationProcessing, diff --git a/tests/benchmark/benchmark_bettertransformer.py b/tests/benchmark/benchmark_bettertransformer.py index 38517d66d2..761053dc26 100644 --- a/tests/benchmark/benchmark_bettertransformer.py +++ b/tests/benchmark/benchmark_bettertransformer.py @@ -167,9 +167,9 @@ def benchmark(model, input_ids, masks, num_batches, is_decoder, max_token, pad_t task = TasksManager.infer_task_from_model(args.model_name) - if task == "causal-lm": + if task == "text-generation": autoclass = AutoModelForCausalLM - elif task == "seq2seq-lm": + elif task == "text2text-generation": autoclass = AutoModelForSeq2SeqLM else: autoclass = AutoModel diff --git a/tests/benchmark/profile_bettertransformer_t5.py b/tests/benchmark/profile_bettertransformer_t5.py index 93f020e511..f5de73aea5 100644 --- a/tests/benchmark/profile_bettertransformer_t5.py 
diff --git a/optimum/utils/preprocessing/task_processors_manager.py b/optimum/utils/preprocessing/task_processors_manager.py
index 4a19713b6a..2720ed41fb 100644
--- a/optimum/utils/preprocessing/task_processors_manager.py
+++ b/optimum/utils/preprocessing/task_processors_manager.py
@@ -28,7 +28,7 @@ class TaskProcessorsManager:
     _TASK_TO_DATASET_PROCESSING_CLASS = {
-        "sequence-classification": TextClassificationProcessing,
+        "text-classification": TextClassificationProcessing,
         "token-classification": TokenClassificationProcessing,
         "question-answering": QuestionAnsweringProcessing,
         "image-classification": ImageClassificationProcessing,
diff --git a/tests/benchmark/benchmark_bettertransformer.py b/tests/benchmark/benchmark_bettertransformer.py
index 38517d66d2..761053dc26 100644
--- a/tests/benchmark/benchmark_bettertransformer.py
+++ b/tests/benchmark/benchmark_bettertransformer.py
@@ -167,9 +167,9 @@ def benchmark(model, input_ids, masks, num_batches, is_decoder, max_token, pad_t
     task = TasksManager.infer_task_from_model(args.model_name)

-    if task == "causal-lm":
+    if task == "text-generation":
         autoclass = AutoModelForCausalLM
-    elif task == "seq2seq-lm":
+    elif task == "text2text-generation":
         autoclass = AutoModelForSeq2SeqLM
     else:
         autoclass = AutoModel
diff --git a/tests/benchmark/profile_bettertransformer_t5.py b/tests/benchmark/profile_bettertransformer_t5.py
index 93f020e511..f5de73aea5 100644
--- a/tests/benchmark/profile_bettertransformer_t5.py
+++ b/tests/benchmark/profile_bettertransformer_t5.py
@@ -157,9 +157,9 @@ def profile_model(model, profile_name, input_ids, masks, num_batches, is_decoder
     task = TasksManager.infer_task_from_model(args.model_name)

-    if task == "causal-lm":
+    if task == "text-generation":
         autoclass = AutoModelForCausalLM
-    elif task == "seq2seq-lm":
+    elif task == "text2text-generation":
         autoclass = AutoModelForSeq2SeqLM
     else:
         autoclass = AutoModel
diff --git a/tests/exporters/common/test_tasks_manager.py b/tests/exporters/common/test_tasks_manager.py
index ac583a1219..88865326f3 100644
--- a/tests/exporters/common/test_tasks_manager.py
+++ b/tests/exporters/common/test_tasks_manager.py
@@ -17,7 +17,9 @@
 from typing import Optional, Set
 from unittest import TestCase

-from transformers import BertConfig
+import pytest
+from transformers import BertConfig, VisualBertForQuestionAnswering
+from transformers.testing_utils import slow

 from optimum.exporters import TasksManager
 from optimum.exporters.onnx.model_configs import BertOnnxConfig
@@ -71,14 +73,14 @@ def test_register(self):
         # Case 1: We try to register a config that was already registered, it should not register anything.
         register_for_onnx = TasksManager.create_register("onnx")

-        @register_for_onnx("bert", "sequence-classification")
+        @register_for_onnx("bert", "text-classification")
         class BadBertOnnxConfig(BertOnnxConfig):
             pass

         bert_config_constructor = TasksManager.get_exporter_config_constructor(
             "onnx",
             model_type="bert",
-            task="sequence-classification",
+            task="text-classification",
         )
         bert_onnx_config = bert_config_constructor(BertConfig())

@@ -92,14 +94,14 @@ class BadBertOnnxConfig(BertOnnxConfig):
         # the new config.
         register_for_onnx = TasksManager.create_register("onnx", overwrite_existing=True)

-        @register_for_onnx("bert", "sequence-classification")
+        @register_for_onnx("bert", "text-classification")
         class BadBertOnnxConfig2(BertOnnxConfig):
             pass

         bert_config_constructor = TasksManager.get_exporter_config_constructor(
             "onnx",
             model_type="bert",
-            task="sequence-classification",
+            task="text-classification",
         )
         bert_onnx_config = bert_config_constructor(BertConfig())

@@ -122,14 +124,14 @@ class UnknownTask(BertOnnxConfig):
         # Case 4: Registering for a new backend.
         register_for_new_backend = TasksManager.create_register("new-backend")

-        @register_for_new_backend("bert", "sequence-classification")
+        @register_for_new_backend("bert", "text-classification")
         class BertNewBackendConfig(BertOnnxConfig):
             pass

         bert_config_constructor = TasksManager.get_exporter_config_constructor(
             "new-backend",
             model_type="bert",
-            task="sequence-classification",
+            task="text-classification",
         )
         bert_onnx_config = bert_config_constructor(BertConfig())

@@ -154,3 +156,12 @@ class BertNewBackendConfigTaskSpecific(BertOnnxConfig):
             BertNewBackendConfigTaskSpecific,
             "Wrong config class compared to the registered one.",
         )
+
+    @slow
+    @pytest.mark.run_slow
+    def test_custom_class(self):
+        task = TasksManager.infer_task_from_model("google/pix2struct-base")
+        self.assertEqual(task, "image-to-text")
+
+        model = TasksManager.get_model_from_task("question-answering", "uclanlp/visualbert-vqa")
+        self.assertTrue(isinstance(model, VisualBertForQuestionAnswering))
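The new `test_custom_class` test exercises the `(framework, model_type, task)` custom-class lookup added in `get_model_class_for_task`: VisualBERT cannot be resolved through an auto class. A sketch of the same call outside the test harness (downloads the checkpoint):

```python
from optimum.exporters import TasksManager

model = TasksManager.get_model_from_task("question-answering", "uclanlp/visualbert-vqa")
print(type(model).__name__)  # VisualBertForQuestionAnswering
```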
diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py
index 850d67c3d4..3520949e91 100644
--- a/tests/exporters/exporters_utils.py
+++ b/tests/exporters/exporters_utils.py
@@ -79,7 +79,7 @@
     # "owlvit": "google/owlvit-base-patch32",
     "pegasus": "hf-internal-testing/tiny-random-PegasusModel",
     "perceiver": {
-        "hf-internal-testing/tiny-random-language_perceiver": ["masked-lm", "sequence-classification"],
+        "hf-internal-testing/tiny-random-language_perceiver": ["fill-mask", "text-classification"],
         "hf-internal-testing/tiny-random-vision_perceiver_conv": ["image-classification"],
     },
     # "rembert": "google/rembert",
@@ -100,7 +100,11 @@
     "wav2vec2": "hf-internal-testing/tiny-random-Wav2Vec2Model",
     "wav2vec2-conformer": "hf-internal-testing/tiny-random-wav2vec2-conformer",
     "wavlm": {
-        "hf-internal-testing/tiny-random-wavlm": ["default", "audio-ctc", "audio-classification"],
+        "hf-internal-testing/tiny-random-wavlm": [
+            "feature-extraction",
+            "automatic-speech-recognition",
+            "audio-classification",
+        ],
         "hf-internal-testing/tiny-random-WavLMForCTC": ["audio-frame-classification"],
         "hf-internal-testing/tiny-random-WavLMForXVector": ["audio-xvector"],
     },
@@ -108,7 +112,11 @@
     "sew-d": "hf-internal-testing/tiny-random-SEWDModel",
     "unispeech": "hf-internal-testing/tiny-random-unispeech",
     "unispeech-sat": {
-        "hf-internal-testing/tiny-random-unispeech-sat": ["default", "audio-ctc", "audio-classification"],
+        "hf-internal-testing/tiny-random-unispeech-sat": [
+            "feature-extraction",
+            "automatic-speech-recognition",
+            "audio-classification",
+        ],
         "hf-internal-testing/tiny-random-UniSpeechSatForPreTraining": ["audio-frame-classification"],
         "hf-internal-testing/tiny-random-UniSpeechSatForXVector": ["audio-xvector"],
     },
@@ -120,10 +128,10 @@
     "xlm-roberta": "hf-internal-testing/tiny-xlm-roberta",
     "vision-encoder-decoder": {
         "hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2": [
-            "vision2seq-lm",
-            "vision2seq-lm-with-past",
+            "image-to-text",
+            "image-to-text-with-past",
         ],
-        "microsoft/trocr-small-handwritten": ["vision2seq-lm"],
+        "microsoft/trocr-small-handwritten": ["image-to-text"],
     },
 }
diff --git a/tests/exporters/onnx/test_exporters_onnx_cli.py b/tests/exporters/onnx/test_exporters_onnx_cli.py
index 9b6a78ad1c..412ef71b53 100644
--- a/tests/exporters/onnx/test_exporters_onnx_cli.py
+++ b/tests/exporters/onnx/test_exporters_onnx_cli.py
@@ -59,7 +59,12 @@ def _get_models_to_test(export_models_dict: Dict):
                 # -with-past and monolith cases are absurd, so we don't test them as not supported
                 if any(
                    task == ort_special_task
-                    for ort_special_task in ["causal-lm", "seq2seq-lm", "speech2seq-lm", "vision2seq-lm"]
+                    for ort_special_task in [
+                        "text-generation",
+                        "text2text-generation",
+                        "automatic-speech-recognition",
+                        "image-to-text",
+                    ]
                 ):
                     models_to_test.append(
                         (f"{model_type}_{task}_monolith", model_type, model_name, task, True, False)
@@ -67,18 +72,18 @@ def _get_models_to_test(export_models_dict: Dict):
                 # For other tasks, we don't test --no-post-process as there is none anyway
                 if task in [
-                    "default-with-past",
-                    "causal-lm-with-past",
-                    "speech2seq-lm-with-past",
-                    "vision2seq-lm-with-past",
-                    "seq2seq-lm-with-past",
+                    "feature-extraction-with-past",
+                    "text-generation-with-past",
+                    "automatic-speech-recognition-with-past",
+                    "image-to-text-with-past",
+                    "text2text-generation-with-past",
                 ]:
                     models_to_test.append(
                         (f"{model_type}_{task}_no_postprocess", model_type, model_name, task, False, True)
                     )

             # TODO: segformer task can not be automatically inferred
-            # TODO: xlm-roberta model auto-infers causal-lm, but we don't support it
+            # TODO: xlm-roberta model auto-infers text-generation, but we don't support it
             # TODO: perceiver auto-infers default, but we don't support it (why?)
             if model_type not in ["segformer", "xlm-roberta", "perceiver", "vision-encoder-decoder"]:
                 models_to_test.append((f"{model_type}_no_task", model_type, model_name, "auto", False, False))
@@ -192,7 +197,7 @@ def test_external_data(self, use_cache: bool):
         os.environ["FORCE_ONNX_EXTERNAL_DATA"] = "1"  # force exporting small model with external data
         with TemporaryDirectory() as tmpdirname:
-            task = "seq2seq-lm"
+            task = "text2text-generation"
             if use_cache:
                 task += "-with-past"
@@ -218,7 +223,7 @@ def test_external_data(self, use_cache: bool):
     def test_trust_remote_code(self):
         with TemporaryDirectory() as tmpdirname:
             out = subprocess.run(
-                f"python3 -m optimum.exporters.onnx --model fxmarty/tiny-testing-gpt2-remote-code --task causal-lm {tmpdirname}",
+                f"python3 -m optimum.exporters.onnx --model fxmarty/tiny-testing-gpt2-remote-code --task text-generation {tmpdirname}",
                 shell=True,
                 capture_output=True,
             )
@@ -227,7 +232,7 @@ def test_trust_remote_code(self):
         with TemporaryDirectory() as tmpdirname:
             out = subprocess.run(
-                f"python3 -m optimum.exporters.onnx --trust-remote-code --model fxmarty/tiny-testing-gpt2-remote-code --task causal-lm {tmpdirname}",
+                f"python3 -m optimum.exporters.onnx --trust-remote-code --model fxmarty/tiny-testing-gpt2-remote-code --task text-generation {tmpdirname}",
                 shell=True,
                 check=True,
             )
@@ -270,3 +275,32 @@ def test_export_on_fp16(
             self.skipTest("ibert can not be supported in fp16")

         self._onnx_export(model_name, task, monolith, no_post_process, fp16=True)
+
+    @parameterized.expand(
+        [
+            ["causal-lm", "gpt2"],
+            ["causal-lm-with-past", "gpt2"],
+            ["seq2seq-lm", "t5"],
+            ["seq2seq-lm-with-past", "t5"],
+            ["speech2seq-lm", "whisper"],
+            ["speech2seq-lm-with-past", "whisper"],
+            ["vision2seq-lm", "vision-encoder-decoder"],
+            ["sequence-classification", "bert"],
+            ["masked-lm", "bert"],
+            ["default", "blenderbot"],
+            ["default-with-past", "blenderbot"],
+            ["audio-ctc", "wav2vec2-conformer"],
+        ]
+    )
+    @slow
+    @pytest.mark.run_slow
+    def test_synonym_tasks_backward_compatibility(self, task: str, model_type: str):
+        model_name = PYTORCH_EXPORT_MODELS_TINY[model_type]
+
+        if isinstance(model_name, dict):
+            for _model_name in model_name.keys():
+                with TemporaryDirectory() as tmpdir:
+                    main_export(model_name_or_path=_model_name, output=tmpdir, task=task)
+        else:
+            with TemporaryDirectory() as tmpdir:
+                main_export(model_name_or_path=model_name, output=tmpdir, task=task)
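The backward-compatibility test above covers the export entry point directly: each legacy task name should be accepted and remapped. Outside the test harness, the same call looks like this (downloads the checkpoint; the output directory is illustrative):

```python
from optimum.exporters.onnx import main_export

# The pre-rename task name is accepted and remapped to "text-generation".
main_export(model_name_or_path="gpt2", output="gpt2_onnx", task="causal-lm")
```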
diff --git a/tests/exporters/onnx/test_onnx_config_loss.py b/tests/exporters/onnx/test_onnx_config_loss.py
index ad7866108c..1eed7d9b61 100644
--- a/tests/exporters/onnx/test_onnx_config_loss.py
+++ b/tests/exporters/onnx/test_onnx_config_loss.py
@@ -55,7 +55,7 @@ def test_onnx_config_with_loss(self):
             with self.subTest(model=model):
                 with tempfile.TemporaryDirectory() as tmp_dir:
                     onnx_config_constructor = TasksManager.get_exporter_config_constructor(
-                        model=model, exporter="onnx", task="sequence-classification"
+                        model=model, exporter="onnx", task="text-classification"
                     )
                     onnx_config = onnx_config_constructor(model.config)
@@ -84,7 +84,7 @@ def test_onnx_config_with_loss(self):
                     framework = "pt" if isinstance(model, PreTrainedModel) else "tf"
                     normalized_config = NormalizedConfigManager.get_normalized_config_class("bert")(model.config)
                     input_generator = DummyTextInputGenerator(
-                        "sequence-classification", normalized_config, batch_size=2, sequence_length=16
+                        "text-classification", normalized_config, batch_size=2, sequence_length=16
                     )

                     inputs = {
@@ -130,7 +130,7 @@ def test_onnx_decoder_model_with_config_with_loss(self):
         # Wrap OnnxConfig
         onnx_config_constructor = TasksManager.get_exporter_config_constructor(
-            model=model, exporter="onnx", task="sequence-classification"
+            model=model, exporter="onnx", task="text-classification"
         )
         onnx_config = onnx_config_constructor(model.config)
         wrapped_onnx_config = OnnxConfigWithLoss(onnx_config)
@@ -181,7 +181,7 @@ def test_onnx_seq2seq_model_with_config_with_loss(self):
         # Wrap OnnxConfig(decoders)
         onnx_config_constructor = TasksManager.get_exporter_config_constructor(
-            model=model, exporter="onnx", task="seq2seq-lm"
+            model=model, exporter="onnx", task="text2text-generation"
         )
         onnx_config = onnx_config_constructor(model.config)
diff --git a/tests/exporters/onnx/test_onnx_export.py b/tests/exporters/onnx/test_onnx_export.py
index ded5ceafb0..5bc08a51db 100644
--- a/tests/exporters/onnx/test_onnx_export.py
+++ b/tests/exporters/onnx/test_onnx_export.py
@@ -168,7 +168,12 @@ def _get_models_to_test(export_models_dict: Dict):
                 if any(
                     task == ort_special_task
-                    for ort_special_task in ["causal-lm", "seq2seq-lm", "speech2seq-lm", "vision2seq-lm"]
+                    for ort_special_task in [
+                        "text-generation",
+                        "text2text-generation",
+                        "automatic-speech-recognition",
+                        "image-to-text",
+                    ]
                 ):
                     models_to_test.append(
                         (
@@ -219,7 +224,7 @@ def _onnx_export(
         if (
             isinstance(onnx_config, OnnxConfigWithPast)
             and getattr(model.config, "pad_token_id", None) is None
-            and task == "sequence-classification"
+            and task == "text-classification"
         ):
             model.config.pad_token_id = 0
@@ -238,11 +243,18 @@ def _onnx_export(
         if (
             model.config.is_encoder_decoder
-            and task.startswith(("seq2seq-lm", "speech2seq-lm", "vision2seq-lm", "default-with-past"))
+            and task.startswith(
+                (
+                    "text2text-generation",
+                    "automatic-speech-recognition",
+                    "image-to-text",
+                    "feature-extraction-with-past",
+                )
+            )
             and monolith is False
         ):
             models_and_onnx_configs = get_encoder_decoder_models_for_export(model, onnx_config)
-        elif task.startswith("causal-lm") and monolith is False:
+        elif task.startswith("text-generation") and monolith is False:
             models_and_onnx_configs = get_decoder_models_for_export(model, onnx_config)
         else:
             models_and_onnx_configs = {"model": (model, onnx_config)}
diff --git a/tests/exporters/tflite/test_exporters_tflite_cli.py b/tests/exporters/tflite/test_exporters_tflite_cli.py
index 5bdf5dd5ad..9adc638234 100644
--- a/tests/exporters/tflite/test_exporters_tflite_cli.py
+++ b/tests/exporters/tflite/test_exporters_tflite_cli.py
@@ -288,7 +288,7 @@ def test_exporters_cli_tflite_int8_quantization_with_custom_dataset(
     def test_trust_remote_code(self):
         with TemporaryDirectory() as tmpdirname:
             out = subprocess.run(
-                f"python3 -m optimum.exporters.tflite --model fxmarty/tiny-testing-gpt2-remote-code --task causal-lm {tmpdirname}",
+                f"python3 -m optimum.exporters.tflite --model fxmarty/tiny-testing-gpt2-remote-code --task text-generation {tmpdirname}",
                 shell=True,
                 capture_output=True,
             )
@@ -297,7 +297,7 @@ def test_trust_remote_code(self):
         with TemporaryDirectory() as tmpdirname:
             out = subprocess.run(
-                f"python3 -m optimum.exporters.tflite --trust-remote-code --model fxmarty/tiny-testing-gpt2-remote-code --task causal-lm {tmpdirname}",
+                f"python3 -m optimum.exporters.tflite --trust-remote-code --model fxmarty/tiny-testing-gpt2-remote-code --task text-generation {tmpdirname}",
                 shell=True,
                 check=True,
             )
diff --git a/tests/onnx/test_onnx_graph_transformations.py b/tests/onnx/test_onnx_graph_transformations.py
index 09ccd005e8..bed539eacc 100644
--- a/tests/onnx/test_onnx_graph_transformations.py
+++ b/tests/onnx/test_onnx_graph_transformations.py
@@ -43,7 +43,7 @@ def test_weight_sharing_output_match(self):
             tokenizer = AutoTokenizer.from_pretrained(model_id)
             model = AutoModel.from_pretrained(model_id)

-            task = "default"
+            task = "feature-extraction"
             with TemporaryDirectory() as tmpdir:
                 subprocess.run(
                     f"python3 -m optimum.exporters.onnx --model {model_id} --task {task} {tmpdir}",
@@ -69,10 +69,10 @@ def test_weight_sharing_output_match(self):
 class OnnxMergingTestCase(TestCase):
     SUPPORTED_ARCHITECTURES_WITH_MODEL_ID = {
-        "hf-internal-testing/tiny-random-GPT2Model": "causal-lm-with-past",
-        "hf-internal-testing/tiny-random-t5": "seq2seq-lm-with-past",
-        "hf-internal-testing/tiny-random-bart": "seq2seq-lm-with-past",
-        "openai/whisper-tiny.en": "speech2seq-lm-with-past",
+        "hf-internal-testing/tiny-random-GPT2Model": "text-generation-with-past",
+        "hf-internal-testing/tiny-random-t5": "text2text-generation-with-past",
+        "hf-internal-testing/tiny-random-bart": "text2text-generation-with-past",
+        "openai/whisper-tiny.en": "automatic-speech-recognition-with-past",
     }

     @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID.items())
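`OnnxMergingTestCase` above pairs each checkpoint with a renamed `-with-past` task; with-past export is the default, and the post-processing step merges the decoder graphs. A sketch of the equivalent programmatic export, using one of the tiny checkpoints from the test (the output directory is illustrative):

```python
from optimum.exporters.onnx import main_export

main_export(
    model_name_or_path="hf-internal-testing/tiny-random-GPT2Model",
    output="gpt2_onnx",  # hypothetical output directory
    task="text-generation-with-past",
)
```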
diff --git a/tests/onnxruntime/nightly_test_trainer.py b/tests/onnxruntime/nightly_test_trainer.py
index 2b42761b89..38bdfd0797 100644
--- a/tests/onnxruntime/nightly_test_trainer.py
+++ b/tests/onnxruntime/nightly_test_trainer.py
@@ -73,7 +73,7 @@
 }

 _ENCODER_TASKS_DATASETS_CONFIGS = {
-    "sequence-classification": {
+    "text-classification": {
         "dataset": ["glue", "sst2"],
         "metric": ["glue", "sst2"],
         "data_collator": default_data_collator,
@@ -87,12 +87,12 @@
 }

 _DECODER_TASKS_DATASETS_CONFIGS = {
-    "causal-lm": {
+    "text-generation": {
         "dataset": ["wikitext", "wikitext-2-raw-v1"],
         "metric": ["accuracy"],
         "data_collator": default_data_collator,
     },
-    "causal-lm-with-past": {
+    "text-generation-with-past": {
         "dataset": ["wikitext", "wikitext-2-raw-v1"],
         "metric": ["accuracy"],
         "data_collator": default_data_collator,
@@ -100,12 +100,12 @@
 }

 _SEQ2SEQ_TASKS_DATASETS_CONFIGS = {
-    "seq2seq-lm": {
+    "text2text-generation": {
         "dataset": ["xsum"],
         "metric": ["rouge"],
         "data_collator_class": DataCollatorForSeq2Seq,
     },
-    "seq2seq-lm-with-past": {
+    "text2text-generation-with-past": {
         "dataset": ["xsum"],
         "metric": ["rouge"],
         "data_collator_class": DataCollatorForSeq2Seq,
@@ -201,12 +201,12 @@ def get_ort_trainer(

 def load_and_prepare(feature):
     preprocess_mapping = {
-        "sequence-classification": load_and_prepare_glue,
+        "text-classification": load_and_prepare_glue,
         "token-classification": load_and_prepare_ner,
-        "causal-lm": load_and_prepare_clm,
-        "causal-lm-with-past": load_and_prepare_clm,
-        "seq2seq-lm": load_and_prepare_xsum,
-        "seq2seq-lm-with-past": load_and_prepare_xsum,
+        "text-generation": load_and_prepare_clm,
+        "text-generation-with-past": load_and_prepare_clm,
+        "text2text-generation": load_and_prepare_xsum,
+        "text2text-generation-with-past": load_and_prepare_xsum,
     }
     return preprocess_mapping[feature]
@@ -653,9 +653,9 @@ def test_trainer_fp16_pt_inference(self, test_name, model_name, feature, data_me
     @parameterized.expand(
         _get_models_to_test(_ENCODERS_TO_TEST, _ENCODER_TASKS_DATASETS_CONFIGS)
         # Exclude "with-past" tests as they fail for ORT inference after the mixed-precision training
-        # + _get_models_to_test(_DECODERS_TO_TEST, _DECODER_TASKS_DATASETS_CONFIGS, excluded=["causal-lm-with-past"])  # Skip test for OOM bug
+        # + _get_models_to_test(_DECODERS_TO_TEST, _DECODER_TASKS_DATASETS_CONFIGS, excluded=["text-generation-with-past"])  # Skip test for OOM bug
         + _get_models_to_test(
-            _SEQ2SEQ_MODELS_TO_TEST, _SEQ2SEQ_TASKS_DATASETS_CONFIGS, excluded=["seq2seq-lm-with-past"]
+            _SEQ2SEQ_MODELS_TO_TEST, _SEQ2SEQ_TASKS_DATASETS_CONFIGS, excluded=["text2text-generation-with-past"]
         ),
         skip_on_empty=True,
     )
@@ -894,7 +894,7 @@ def setUp(self):
         self.weight_decay = 0.01

         self.model_name = "bert-base-cased"
-        self.feature = "sequence-classification"
+        self.feature = "text-classification"

     def check_optim_and_kwargs(self, optim: OptimizerNames, mandatory_kwargs, expected_cls):
         args = ORTTrainingArguments(optim=optim, output_dir="None")
diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py
index 1384e3c037..52bb1a1984 100644
--- a/tests/onnxruntime/test_modeling.py
+++ b/tests/onnxruntime/test_modeling.py
@@ -1271,7 +1271,7 @@ class ORTModelForMaskedLMIntegrationTest(ORTModelTestMixin):
     FULL_GRID = {"model_arch": SUPPORTED_ARCHITECTURES}
     ORTMODEL_CLASS = ORTModelForMaskedLM

-    TASK = "masked-lm"
+    TASK = "fill-mask"

     def test_load_vanilla_transformers_which_is_not_supported(self):
         with self.assertRaises(Exception) as context:
@@ -1433,7 +1433,7 @@ class ORTModelForSequenceClassificationIntegrationTest(ORTModelTestMixin):
     FULL_GRID = {"model_arch": SUPPORTED_ARCHITECTURES}
     ORTMODEL_CLASS = ORTModelForSequenceClassification

-    TASK = "sequence-classification"
+    TASK = "text-classification"

     def test_load_vanilla_transformers_which_is_not_supported(self):
         with self.assertRaises(Exception) as context:
@@ -1735,7 +1735,7 @@ class ORTModelForFeatureExtractionIntegrationTest(ORTModelTestMixin):
     FULL_GRID = {"model_arch": SUPPORTED_ARCHITECTURES}
     ORTMODEL_CLASS = ORTModelForFeatureExtraction

-    TASK = "default"
+    TASK = "feature-extraction"

     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     def test_compare_to_transformers(self, model_arch):
@@ -1979,7 +1979,7 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin):
     }
     ORTMODEL_CLASS = ORTModelForCausalLM

-    TASK = "causal-lm"
+    TASK = "text-generation"

     GENERATION_LENGTH = 100
     SPEEDUP_CACHE = 1.1
@@ -2040,7 +2040,7 @@ def test_generate_utils(self, test_name: str, model_arch: str, use_cache: str):
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     def test_merge_from_transformers_and_save(self, model_arch):
-        if "causal-lm-with-past" not in TasksManager.get_supported_tasks_for_model_type(
+        if "text-generation-with-past" not in TasksManager.get_supported_tasks_for_model_type(
"text-generation-with-past" not in TasksManager.get_supported_tasks_for_model_type( model_arch.replace("_", "-"), exporter="onnx" ): self.skipTest("Unsupported -with-past export case") @@ -2059,7 +2059,7 @@ def test_merge_from_transformers_and_save(self, model_arch): @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_merge_from_onnx_and_save(self, model_arch): model_id = MODEL_NAMES[model_arch] - task = "causal-lm-with-past" + task = "text-generation-with-past" if task not in TasksManager.get_supported_tasks_for_model_type(model_arch.replace("_", "-"), exporter="onnx"): self.skipTest("Unsupported export case") @@ -3066,7 +3066,7 @@ class ORTModelForSeq2SeqLMIntegrationTest(ORTModelTestMixin): } ORTMODEL_CLASS = ORTModelForSeq2SeqLM - TASK = "seq2seq-lm" + TASK = "text2text-generation" GENERATION_LENGTH = 100 SPEEDUP_CACHE = 1.1 @@ -3111,7 +3111,7 @@ def test_generate_utils(self, test_name: str, model_arch: str, use_cache: str): @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_merge_from_transformers_and_save(self, model_arch): - if "seq2seq-lm-with-past" not in TasksManager.get_supported_tasks_for_model_type( + if "text2text-generation-with-past" not in TasksManager.get_supported_tasks_for_model_type( model_arch.replace("_", "-"), exporter="onnx" ): self.skipTest("Unsupported -with-past export case") @@ -3131,7 +3131,7 @@ def test_merge_from_transformers_and_save(self, model_arch): @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_merge_from_onnx_and_save(self, model_arch): model_id = MODEL_NAMES[model_arch] - task = "seq2seq-lm-with-past" + task = "text2text-generation-with-past" if task not in TasksManager.get_supported_tasks_for_model_type(model_arch.replace("_", "-"), exporter="onnx"): self.skipTest("Unsupported export case") @@ -3523,7 +3523,7 @@ class ORTModelForSpeechSeq2SeqIntegrationTest(ORTModelTestMixin): } ORTMODEL_CLASS = ORTModelForSpeechSeq2Seq - TASK = "speech2seq-lm" + TASK = "automatic-speech-recognition" GENERATION_LENGTH = 100 SPEEDUP_CACHE = 1.1 @@ -3538,7 +3538,7 @@ def _generate_random_audio_data(self): @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_merge_from_transformers_and_save(self, model_arch): - if "speech2seq-lm-with-past" not in TasksManager.get_supported_tasks_for_model_type( + if "automatic-speech-recognition-with-past" not in TasksManager.get_supported_tasks_for_model_type( model_arch.replace("_", "-"), exporter="onnx" ): self.skipTest("Unsupported -with-past export case") @@ -3558,7 +3558,7 @@ def test_merge_from_transformers_and_save(self, model_arch): @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_merge_from_onnx_and_save(self, model_arch): model_id = MODEL_NAMES[model_arch] - task = "speech2seq-lm-with-past" + task = "automatic-speech-recognition-with-past" if task not in TasksManager.get_supported_tasks_for_model_type(model_arch.replace("_", "-"), exporter="onnx"): self.skipTest("Unsupported export case") @@ -3924,7 +3924,7 @@ class ORTModelForVision2SeqIntegrationTest(ORTModelTestMixin): ORTMODEL_CLASS = ORTModelForVision2Seq - TASK = "vision2seq-lm" + TASK = "image-to-text" GENERATION_LENGTH = 100 SPEEDUP_CACHE = 1.1 @@ -4232,17 +4232,17 @@ class TestBothExportersORTModel(unittest.TestCase): @parameterized.expand( [ ["question-answering", ORTModelForQuestionAnsweringIntegrationTest], - ["sequence-classification", ORTModelForSequenceClassificationIntegrationTest], + ["text-classification", ORTModelForSequenceClassificationIntegrationTest], ["token-classification", ORTModelForTokenClassificationIntegrationTest], - 
["default", ORTModelForFeatureExtractionIntegrationTest], + ["feature-extraction", ORTModelForFeatureExtractionIntegrationTest], ["multiple-choice", ORTModelForMultipleChoiceIntegrationTest], - ["causal-lm", ORTModelForCausalLMIntegrationTest], + ["text-generation", ORTModelForCausalLMIntegrationTest], ["image-classification", ORTModelForImageClassificationIntegrationTest], ["semantic-segmentation", ORTModelForSemanticSegmentationIntegrationTest], - ["seq2seq-lm", ORTModelForSeq2SeqLMIntegrationTest], - ["speech2seq-lm", ORTModelForSpeechSeq2SeqIntegrationTest], + ["text2text-generation", ORTModelForSeq2SeqLMIntegrationTest], + ["automatic-speech-recognition", ORTModelForSpeechSeq2SeqIntegrationTest], ["audio-classification", ORTModelForAudioClassificationIntegrationTest], - ["audio-ctc", ORTModelForCTCIntegrationTest], + ["automatic-speech-recognition", ORTModelForCTCIntegrationTest], ["audio-xvector", ORTModelForAudioXVectorIntegrationTest], ["audio-frame-classification", ORTModelForAudioFrameClassificationIntegrationTest], ] diff --git a/tests/onnxruntime/test_optimization.py b/tests/onnxruntime/test_optimization.py index 063b11b49f..2d68a947f0 100644 --- a/tests/onnxruntime/test_optimization.py +++ b/tests/onnxruntime/test_optimization.py @@ -206,7 +206,7 @@ def test_optimization_fp16(self): class ORTOptimizerForSeq2SeqLMIntegrationTest(ORTOptimizerTestMixin): - TASK = "seq2seq-lm" + TASK = "text2text-generation" ORTMODEL_CLASS = ORTModelForSeq2SeqLM SUPPORTED_ARCHITECTURES = [ @@ -323,7 +323,7 @@ def test_optimization_levels_gpu(self, test_name: str, model_arch: str, use_cach class ORTOptimizerForCausalLMIntegrationTest(ORTOptimizerTestMixin): - TASK = "causal-lm" + TASK = "text-generation" ORTMODEL_CLASS = ORTModelForCausalLM SUPPORTED_ARCHITECTURES = [ diff --git a/tests/utils/test_task_processors.py b/tests/utils/test_task_processors.py index e598a20bdb..f8a0a6d5a9 100644 --- a/tests/utils/test_task_processors.py +++ b/tests/utils/test_task_processors.py @@ -37,7 +37,7 @@ IMAGE_PROCESSOR = AutoFeatureExtractor.from_pretrained(IMAGE_MODEL_NAME) TASK_TO_NON_DEFAULT_DATASET = { - "sequence-classification": { + "text-classification": { "dataset_args": {"path": "glue", "name": "mnli"}, "dataset_data_keys": {"primary": "premise", "secondary": "hypothesis"}, }, @@ -189,7 +189,7 @@ def test_load_default_dataset(self): class TextClassificationProcessorTest(TestCase, TaskProcessorTestBase): - TASK_NAME = "sequence-classification" + TASK_NAME = "text-classification" CONFIG = CONFIG PREPROCESSOR = TOKENIZER WRONG_PREPROCESSOR = IMAGE_PROCESSOR