
Commit

Allow to use a custom class in TasksManager & use canonical tasks names (#967)

* fix task names

* add test

* fix fill-mask

* fix test

* fix fill mask and vision2seq

* fix

* add test

* fix feature extraction

* fix audio ctc

* add warning

* update comment

* fix bug

* fix tests and comments

* nit

* fix test
fxmarty authored Apr 12, 2023
1 parent 2105a8a commit f7f1ef1
Showing 36 changed files with 680 additions and 492 deletions.
8 changes: 4 additions & 4 deletions docs/source/exporters/onnx/usage_guides/contribute.mdx
@@ -129,7 +129,7 @@ classification head, we could use:
>>> from transformers import AutoConfig

>>> config = AutoConfig.from_pretrained("bert-base-uncased")
->>> onnx_config_for_seq_clf = BertOnnxConfig(config, task="sequence-classification")
+>>> onnx_config_for_seq_clf = BertOnnxConfig(config, task="text-classification")
>>> print(onnx_config_for_seq_clf.outputs)
OrderedDict([('logits', {0: 'batch_size'})])
```
@@ -157,9 +157,9 @@ For BERT, it looks as follows:
```python
"bert": supported_tasks_mapping(
"default",
"masked-lm",
"causal-lm",
"sequence-classification",
"fill-mask",
"text-generation",
"text-classification",
"multiple-choice",
"token-classification",
"question-answering",
6 changes: 3 additions & 3 deletions docs/source/exporters/onnx/usage_guides/export_a_model.mdx
@@ -87,7 +87,7 @@ Required arguments:
output Path indicating the directory where to store generated ONNX model.

Optional arguments:
-  --task TASK           The task to export the model for. If not specified, the task will be auto-inferred based on the model. Available tasks depend on the model, but are among: ['default', 'masked-lm', 'causal-lm', 'seq2seq-lm', 'sequence-classification', 'token-classification', 'multiple-choice', 'object-detection', 'question-answering', 'image-classification', 'image-segmentation', 'masked-im', 'semantic-segmentation', 'speech2seq-lm', 'audio-classification', 'audio-frame-classification', 'audio-ctc', 'audio-xvector', 'vision2seq-lm', 'stable-diffusion', 'zero-shot-object-detection']. For decoder models, use `xxx-with-past` to export the model using past key values in the decoder.
+  --task TASK           The task to export the model for. If not specified, the task will be auto-inferred based on the model. Available tasks depend on the model, but are among: ['default', 'fill-mask', 'text-generation', 'text2text-generation', 'text-classification', 'token-classification', 'multiple-choice', 'object-detection', 'question-answering', 'image-classification', 'image-segmentation', 'masked-im', 'semantic-segmentation', 'automatic-speech-recognition', 'audio-classification', 'audio-frame-classification', 'automatic-speech-recognition', 'audio-xvector', 'image-to-text', 'stable-diffusion', 'zero-shot-object-detection']. For decoder models, use `xxx-with-past` to export the model using past key values in the decoder.
--monolith Force to export the model as a single ONNX file. By default, the ONNX exporter may break the model in several ONNX files, for example for encoder-decoder models where the encoder should be run only once while the decoder is looped over.
--device DEVICE The device to use to do the export. Defaults to "cpu".
--opset OPSET If specified, ONNX opset version to export the model with. Otherwise, the default opset will be used.
@@ -198,7 +198,7 @@ Models exported through `optimum-cli export onnx` can be used directly in [`~onnxruntime.ORTModel`]

When exporting a decoder model used for generation, it can be useful to encapsulate in the exported ONNX the [reuse of past keys and values](https://discuss.huggingface.co/t/what-is-the-purpose-of-use-cache-in-decoder/958/2). This avoids recomputing the same intermediate activations at each generation step.

-In the ONNX export, the past keys/values are reused by default. This behavior corresponds to `--task seq2seq-lm-with-past`, `--task causal-lm-with-past`, or `--task speech2seq-lm-with-past`. If for any purpose you would like to disable the export with past keys/values reuse, passing explicitly to `optimum-cli export onnx` the task `seq2seq-lm`, `causal-lm` or `speech2seq-lm` is required.
+In the ONNX export, the past keys/values are reused by default. This behavior corresponds to `--task text2text-generation-with-past`, `--task text-generation-with-past`, or `--task automatic-speech-recognition-with-past`. If for any purpose you would like to disable the export with past keys/values reuse, passing explicitly to `optimum-cli export onnx` the task `text2text-generation`, `text-generation` or `automatic-speech-recognition` is required.

A model exported using past key/values can be reused directly in Optimum's [`~onnxruntime.ORTModel`]:

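(Illustrative note, not part of this diff: a decoder exported with the default `-with-past` behavior can be loaded back through Optimum's ONNX Runtime classes. The sketch below assumes `optimum[onnxruntime]` is installed and that `from_pretrained(..., export=True)` is available to convert the checkpoint on the fly; `gpt2` is just a placeholder model.)

```python
# Hedged sketch: load a decoder exported with past key/values reuse
# (the default, i.e. text-generation-with-past) and generate.
from optimum.onnxruntime import ORTModelForCausalLM
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = ORTModelForCausalLM.from_pretrained("gpt2", export=True)  # export on the fly

inputs = tokenizer("Canonical task names are", return_tensors="pt")
# Past key/values are reused across decoding steps instead of being recomputed.
outputs = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```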
@@ -235,7 +235,7 @@ For each model architecture, you can find the list of supported tasks via the [`~exporters.tasks.TasksManager`]

>>> distilbert_tasks = list(TasksManager.get_supported_tasks_for_model_type("distilbert", "onnx").keys())
>>> print(distilbert_tasks)
-['default', 'masked-lm', 'sequence-classification', 'multiple-choice', 'token-classification', 'question-answering']
+['default', 'fill-mask', 'text-classification', 'multiple-choice', 'token-classification', 'question-answering']
```

You can then pass one of these tasks to the `--task` argument in the `optimum-cli export onnx` command, as mentioned above.
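As an aside (added here for illustration, not part of the diffed docs), the same export can be driven programmatically; this sketch assumes the `main_export` entry point keeps the signature visible in `optimum/exporters/onnx/__main__.py` below, and the output directory name is arbitrary:

```python
# Hedged sketch: programmatic equivalent of `optimum-cli export onnx`
# using a canonical task name.
from optimum.exporters.onnx import main_export

main_export(
    "distilbert-base-uncased",   # model name or local path
    output="distilbert_onnx",    # where the ONNX file(s) are written
    task="text-classification",  # canonical name for "sequence-classification"
)
```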
22 changes: 11 additions & 11 deletions docs/source/exporters/task_manager.mdx
@@ -35,7 +35,7 @@ It is possible to know which tasks are supported for a model for a given backend
>>> distilbert_tasks = list(TasksManager.get_supported_tasks_for_model_type(model_type, backend).keys())
>>> print(distilbert_tasks)
-['default', 'masked-lm', 'sequence-classification', 'multiple-choice', 'token-classification', 'question-answering']
+['default', 'fill-mask', 'text-classification', 'multiple-choice', 'token-classification', 'question-answering']
```
</Tip>
@@ -44,31 +44,31 @@

### PyTorch
| Task | Auto Class |
|--------------------------------------|--------------------------------------|
-| `causal-lm`, `causal-lm-with-past` | `AutoModelForCausalLM` |
-| `default`, `default-with-past` | `AutoModel` |
-| `masked-lm` | `AutoModelForMaskedLM` |
+| `text-generation`, `text-generation-with-past` | `AutoModelForCausalLM` |
+| `feature-extraction`, `feature-extraction-with-past` | `AutoModel` |
+| `fill-mask` | `AutoModelForMaskedLM` |
| `question-answering` | `AutoModelForQuestionAnswering` |
-| `seq2seq-lm`, `seq2seq-lm-with-past` | `AutoModelForSeq2SeqLM` |
-| `sequence-classification` | `AutoModelForSequenceClassification` |
+| `text2text-generation`, `text2text-generation-with-past` | `AutoModelForSeq2SeqLM` |
+| `text-classification` | `AutoModelForSequenceClassification` |
| `token-classification` | `AutoModelForTokenClassification` |
| `multiple-choice` | `AutoModelForMultipleChoice` |
| `image-classification` | `AutoModelForImageClassification` |
| `object-detection` | `AutoModelForObjectDetection` |
| `image-segmentation` | `AutoModelForImageSegmentation` |
| `masked-im` | `AutoModelForMaskedImageModeling` |
| `semantic-segmentation` | `AutoModelForSemanticSegmentation` |
-| `speech2seq-lm` | `AutoModelForSpeechSeq2Seq` |
+| `automatic-speech-recognition` | `AutoModelForSpeechSeq2Seq` |
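For illustration (not part of the diff), the mapping above can also be resolved programmatically; this assumes `TasksManager.get_model_class_for_task` accepts a `framework` argument as in current Optimum:

```python
# Hedged sketch: look up the auto class behind a canonical task name.
from optimum.exporters.tasks import TasksManager

cls = TasksManager.get_model_class_for_task("text-classification", framework="pt")
print(cls.__name__)  # expected: AutoModelForSequenceClassification
```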
### TensorFlow
| Task | Auto Class |
|--------------------------------------|----------------------------------------|
-| `causal-lm`, `causal-lm-with-past` | `TFAutoModelForCausalLM` |
+| `text-generation`, `text-generation-with-past` | `TFAutoModelForCausalLM` |
| `default`, `default-with-past` | `TFAutoModel` |
-| `masked-lm` | `TFAutoModelForMaskedLM` |
+| `fill-mask` | `TFAutoModelForMaskedLM` |
| `question-answering` | `TFAutoModelForQuestionAnswering` |
-| `seq2seq-lm`, `seq2seq-lm-with-past` | `TFAutoModelForSeq2SeqLM` |
-| `sequence-classification` | `TFAutoModelForSequenceClassification` |
+| `text2text-generation`, `text2text-generation-with-past` | `TFAutoModelForSeq2SeqLM` |
+| `text-classification` | `TFAutoModelForSequenceClassification` |
| `token-classification` | `TFAutoModelForTokenClassification` |
| `multiple-choice` | `TFAutoModelForMultipleChoice` |
| `semantic-segmentation` | `TFAutoModelForSemanticSegmentation` |
4 changes: 2 additions & 2 deletions docs/source/exporters/tflite/usage_guides/export_a_model.mdx
@@ -56,9 +56,9 @@ Required arguments:

Optional arguments:
--task TASK The task to export the model for. If not specified, the task will be auto-inferred based on the model. Available tasks depend on
-                        the model, but are among: ['default', 'masked-lm', 'causal-lm', 'seq2seq-lm', 'sequence-classification', 'token-classification',
+                        the model, but are among: ['default', 'fill-mask', 'text-generation', 'text2text-generation', 'text-classification', 'token-classification',
                        'multiple-choice', 'object-detection', 'question-answering', 'image-classification', 'image-segmentation', 'masked-im', 'semantic-
-                        segmentation', 'speech2seq-lm', 'audio-classification', 'audio-frame-classification', 'audio-ctc', 'audio-xvector', 'vision2seq-
+                        segmentation', 'automatic-speech-recognition', 'audio-classification', 'audio-frame-classification', 'automatic-speech-recognition', 'audio-xvector', 'vision2seq-
lm', 'stable-diffusion', 'zero-shot-object-detection']. For decoder models, use `xxx-with-past` to export the model using past key
values in the decoder.
--atol ATOL If specified, the absolute difference tolerance when validating the model. Otherwise, the default atol for the model will be used.
4 changes: 2 additions & 2 deletions docs/source/onnxruntime/usage_guides/trainer.mdx
@@ -123,7 +123,7 @@ empowered by ONNX Runtime. Here is an example of how to use `ORTTrainer` compare
model=model,
args=training_args,
train_dataset=train_dataset,
+ feature="sequence-classification",
+ feature="text-classification",
...
)

@@ -159,7 +159,7 @@ empowered by ONNX Runtime. Here is an example of how to use `ORTSeq2SeqTrainer`
model=model,
args=training_args,
train_dataset=train_dataset,
+ feature="seq2seq-lm",
+ feature="text2text-generation",
...
)

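To see the renamed `feature` argument in a complete (if toy) setting, here is a hedged, self-contained sketch; the two-sample dataset and hyperparameters are invented, and `onnxruntime-training` must be installed for `ORTTrainer` to run:

```python
# Hedged sketch of ORTTrainer with the canonical feature name
# ("text-classification" instead of the old "sequence-classification").
import torch
from torch.utils.data import Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from optimum.onnxruntime import ORTTrainer, ORTTrainingArguments

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")

class ToyDataset(Dataset):
    """Two hard-coded examples, only to make the sketch self-contained."""
    def __init__(self):
        enc = tokenizer(["great movie", "terrible movie"], padding="max_length",
                        truncation=True, max_length=16)
        self.items = []
        for i, label in enumerate([1, 0]):
            item = {k: torch.tensor(v[i]) for k, v in enc.items()}
            item["labels"] = torch.tensor(label)
            self.items.append(item)
    def __len__(self):
        return len(self.items)
    def __getitem__(self, i):
        return self.items[i]

trainer = ORTTrainer(
    model=model,
    args=ORTTrainingArguments(output_dir="ort_out", num_train_epochs=1,
                              per_device_train_batch_size=2, report_to=[]),
    train_dataset=ToyDataset(),
    feature="text-classification",  # canonical name introduced by this PR
)
trainer.train()
```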
2 changes: 1 addition & 1 deletion docs/source/quicktour.mdx
@@ -194,7 +194,7 @@ To train transformers with ONNX Runtime's acceleration features, 🤗 Optimum pr
model=model,
args=training_args,
train_dataset=train_dataset,
+ feature="sequence-classification", # The model type to export to ONNX
+ feature="text-classification", # The model type to export to ONNX
...
)

2 changes: 1 addition & 1 deletion examples/onnxruntime/training/language-modeling/run_clm.py
@@ -528,7 +528,7 @@ def compute_metrics(eval_preds):
preprocess_logits_for_metrics=preprocess_logits_for_metrics
if training_args.do_eval and not is_torch_tpu_available()
else None,
feature="causal-lm",
feature="text-generation",
)

# Training
2 changes: 1 addition & 1 deletion examples/onnxruntime/training/language-modeling/run_mlm.py
@@ -563,7 +563,7 @@ def compute_metrics(eval_preds):
preprocess_logits_for_metrics=preprocess_logits_for_metrics
if training_args.do_eval and not is_torch_tpu_available()
else None,
feature="masked-lm",
feature="fill-mask",
)

# Training
@@ -652,7 +652,7 @@ def compute_metrics(eval_preds):
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
feature="seq2seq-lm",
feature="text2text-generation",
)

# Training
@@ -532,7 +532,7 @@ def compute_metrics(p: EvalPrediction):
compute_metrics=compute_metrics,
tokenizer=tokenizer,
data_collator=data_collator,
feature="sequence-classification",
feature="text-classification",
)

# Training
@@ -575,7 +575,7 @@ def compute_metrics(eval_preds):
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
feature="seq2seq-lm",
feature="text2text-generation",
)

# Training
4 changes: 2 additions & 2 deletions optimum/commands/export/tflite.py
@@ -189,7 +189,7 @@ def parse_args_tflite(parser: "ArgumentParser"):
default=None,
help=(
"The name of the column in the dataset containing the main data to preprocess. "
"Only for sequence-classification and token-classification. "
"Only for text-classification and token-classification. "
),
)
calibration_dataset_group.add_argument(
@@ -198,7 +198,7 @@ def parse_args_tflite(parser: "ArgumentParser"):
default=None,
help=(
"The name of the second column in the dataset containing the main data to preprocess, not always needed. "
"Only for sequence-classification and token-classification. "
"Only for text-classification and token-classification. "
),
)
calibration_dataset_group.add_argument(
32 changes: 20 additions & 12 deletions optimum/exporters/onnx/__main__.py
@@ -146,14 +146,7 @@ def main_export(
)

original_task = task
-    # Infer the task
-    if task == "auto":
-        try:
-            task = TasksManager.infer_task_from_model(model_name_or_path)
-        except KeyError as e:
-            raise KeyError(
-                f"The task could not be automatically inferred. Please provide the argument --task with the task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
-            )
+    task = TasksManager.map_from_synonym(task)

framework = TasksManager.determine_framework(model_name_or_path, subfolder=subfolder, framework=framework)

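For readers skimming the diff: `map_from_synonym` is what lets the legacy task names keep working. Below is a rough, illustrative sketch of the idea (not the actual Optimum implementation; the `-with-past` handling is an assumption):

```python
# Illustrative only -- not the actual Optimum code. The pairs below are
# exactly the renames visible elsewhere in this diff.
_LEGACY_TO_CANONICAL = {
    "masked-lm": "fill-mask",
    "causal-lm": "text-generation",
    "seq2seq-lm": "text2text-generation",
    "sequence-classification": "text-classification",
    "speech2seq-lm": "automatic-speech-recognition",
    "vision2seq-lm": "image-to-text",
}

def map_from_synonym(task: str) -> str:
    # Assumption: "-with-past" variants map like their base task,
    # e.g. "causal-lm-with-past" -> "text-generation-with-past".
    suffix = "-with-past"
    if task.endswith(suffix):
        return map_from_synonym(task[: -len(suffix)]) + suffix
    return _LEGACY_TO_CANONICAL.get(task, task)

print(map_from_synonym("causal-lm-with-past"))  # -> text-generation-with-past
```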
@@ -187,6 +180,14 @@
torch_dtype=torch_dtype,
)

if task == "auto":
try:
task = TasksManager.infer_task_from_model(model_name_or_path)
except KeyError as e:
raise KeyError(
f"The task could not be automatically inferred. Please provide the argument --task with the task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
)

if task != "stable-diffusion" and task + "-with-past" in TasksManager.get_supported_tasks_for_model_type(
model.config.model_type.replace("_", "-"), "onnx"
):
@@ -215,7 +216,7 @@
needs_pad_token_id = (
isinstance(onnx_config, OnnxConfigWithPast)
and getattr(model.config, "pad_token_id", None) is None
and task in ["sequence_classification"]
and task in ["text-classification"]
)
if needs_pad_token_id:
if pad_token_id is not None:
@@ -265,7 +266,7 @@
model.feature_extractor.save_pretrained(output.joinpath("feature_extractor"))
model.save_config(output)
else:
-        if model.config.is_encoder_decoder and task.startswith("causal-lm"):
+        if model.config.is_encoder_decoder and task.startswith("text-generation"):
raise ValueError(
f"model.config.is_encoder_decoder is True and task is `{task}`, which are incompatible. If the task was auto-inferred, please fill a bug report"
f"at https://github.com/huggingface/optimum, if --task was explicitely passed, make sure you selected the right task for the model,"
@@ -275,11 +276,18 @@
onnx_files_subpaths = None
if (
model.config.is_encoder_decoder
and task.startswith(("seq2seq-lm", "speech2seq-lm", "vision2seq-lm", "default-with-past"))
and task.startswith(
(
"text2text-generation",
"automatic-speech-recognition",
"image-to-text",
"feature-extraction-with-past",
)
)
and not monolith
):
models_and_onnx_configs = get_encoder_decoder_models_for_export(model, onnx_config)
elif task.startswith("causal-lm") and not monolith:
elif task.startswith("text-generation") and not monolith:
models_and_onnx_configs = get_decoder_models_for_export(model, onnx_config)
else:
models_and_onnx_configs = {"model": (model, onnx_config)}
