ORT optimizer refactorization (#294)
* Refactorization of ORTOptimizer

* Refactorization of ORTModel

* Adapt examples according to refactorization

* Adapt tests

* Fix style

* Remove quantizer modification

* Fix style

* Apply modifications from #270 for quantizer and optimizer to have same behavior

* Add test for optimization of Seq2Seq models

* Fix style

* Add ort config saving when optimizing a model

* Add ort config saving when quantizing a model

* Add tests

* Fix style

* Adapt optimization examples

* Fix readme

* Remove unused parameter

* Adapt quantization examples

* Fix quantized model and ort config saving

* Add documentation

* Add model configuration saving to simplify loading of optimized model

* Fix style

* Fix description

* Fix quantization tests

* Remove opset argument which is onnx config default opset when exporting with ORTModels
echarlaix authored Aug 24, 2022
1 parent 122a9d8 commit fb7e303
Showing 21 changed files with 495 additions and 508 deletions.
95 changes: 95 additions & 0 deletions docs/source/onnxruntime/optimization.mdx
@@ -14,6 +14,101 @@ specific language governing permissions and limitations under the License.

🤗 Optimum provides an `optimum.onnxruntime` package that enables you to apply graph optimization to many models hosted on the 🤗 hub using the [ONNX Runtime](https://github.com/microsoft/onnxruntime/tree/master/onnxruntime/python/tools/transformers) model optimization tool.


## Creating an `ORTOptimizer`

The `ORTOptimizer` class is used to optimize your ONNX model. The class can be initialized using the `from_pretrained()` method, which supports different checkpoint formats.

1. Using an already initialized `ORTModelForXXX` class.

```python
>>> from optimum.onnxruntime import ORTOptimizer, ORTModelForSequenceClassification

# Load an ONNX model from the Hub
>>> model = ORTModelForSequenceClassification.from_pretrained("optimum/distilbert-base-uncased-finetuned-sst-2-english")
# Create an optimizer from an ORTModelForXXX
>>> optimizer = ORTOptimizer.from_pretrained(model)
```
2. Using a local ONNX model from a directory.
```python
>>> from optimum.onnxruntime import ORTOptimizer
# This assumes a model.onnx exists in path/to/model
>>> optimizer = ORTOptimizer.from_pretrained("path/to/model")
```


## Optimization examples

Below is an end-to-end example showing how to optimize [distilbert-base-uncased-finetuned-sst-2-english](https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english).

```python
>>> from optimum.onnxruntime import ORTOptimizer, ORTModelForSequenceClassification
>>> from optimum.onnxruntime.configuration import OptimizationConfig

>>> model_id = "distilbert-base-uncased-finetuned-sst-2-english"
>>> save_dir = "/tmp/outputs"

# Load a PyTorch model and export it to the ONNX format
>>> model = ORTModelForSequenceClassification.from_pretrained(model_id, from_transformers=True)

# Create the optimizer
>>> optimizer = ORTOptimizer.from_pretrained(model)

# Define the optimization strategy by creating the appropriate configuration
>>> optimization_config = OptimizationConfig(
        optimization_level=2,
        optimize_with_onnxruntime_only=False,
        optimize_for_gpu=False,
    )

# Optimize the model
>>> optimizer.optimize(save_dir=save_dir, optimization_config=optimization_config)
```
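
The optimized model is written to `save_dir`, under the default file name `model_optimized.onnx` (the same name used in the refactored example scripts). A minimal sketch of loading it back for inference, assuming that default file name and the `file_name` argument of `from_pretrained`:

```python
>>> from transformers import AutoTokenizer

# Load the optimized model back (file name assumed to be the default "model_optimized.onnx")
>>> optimized_model = ORTModelForSequenceClassification.from_pretrained(save_dir, file_name="model_optimized.onnx")

# Run a quick sanity-check inference
>>> tokenizer = AutoTokenizer.from_pretrained(model_id)
>>> inputs = tokenizer("This movie was great!", return_tensors="pt")
>>> outputs = optimized_model(**inputs)
>>> outputs.logits.argmax(-1)
```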


Below is an end-to-end example showing how to optimize a Seq2Seq model, [sshleifer/distilbart-cnn-12-6](https://huggingface.co/sshleifer/distilbart-cnn-12-6).

```python
>>> from optimum.onnxruntime import ORTOptimizer, ORTModelForSeq2SeqLM
>>> from optimum.onnxruntime.configuration import OptimizationConfig
>>> from transformers import AutoTokenizer

>>> model_id = "sshleifer/distilbart-cnn-12-6"
>>> save_dir = "/tmp/outputs"

# Load a PyTorch model and export it to the ONNX format
>>> model = ORTModelForSeq2SeqLM.from_pretrained(model_id, from_transformers=True)

# Create the optimizer
>>> optimizer = ORTOptimizer.from_pretrained(model)

# Define the optimization strategy by creating the appropriate configuration
>>> optimization_config = OptimizationConfig(
        optimization_level=2,
        optimize_with_onnxruntime_only=False,
        optimize_for_gpu=False,
    )

# Optimize the model
>>> optimizer.optimize(save_dir=save_dir, optimization_config=optimization_config)

# Load the resulting optimized model
>>> optimized_model = ORTModelForSeq2SeqLM.from_pretrained(
        save_dir,
        encoder_file_name="encoder_model_optimized.onnx",
        decoder_file_name="decoder_model_optimized.onnx",
        decoder_file_with_past_name="decoder_with_past_model_optimized.onnx",
    )
>>> tokenizer = AutoTokenizer.from_pretrained(model_id)
>>> tokens = tokenizer("This is a sample input", return_tensors="pt")
>>> outputs = optimized_model.generate(**tokens)
```
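
To inspect the generated text, the token IDs returned by `generate()` can be decoded with the tokenizer (a small follow-up to the example above):

```python
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```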


## ORTOptimizer

[[autodoc]] onnxruntime.optimization.ORTOptimizer
2 changes: 1 addition & 1 deletion docs/source/onnxruntime/quantization.mdx
@@ -124,7 +124,7 @@ The `ORTQuantizer` class can be used to statically quantize your ONNX model. Bel
)
```

## Quantize Seq2Seq models.
## Quantize Seq2Seq models

The `ORTQuantizer` currently doesn't support multi-file models, like `ORTModelForSeq2SeqLM`. If you want to quantize a Seq2Seq model, you have to quantize each of the model's components individually using the `ORTQuantizer` class. Currently, only dynamic quantization is supported for Seq2Seq models.
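
As a rough illustration (not part of this diff), quantizing each component separately might look like the sketch below; the `file_name` argument, the `AutoQuantizationConfig.avx512_vnni` preset, the local directory names, and the per-component ONNX file names are assumptions that may differ across Optimum versions:

```python
>>> from optimum.onnxruntime import ORTModelForSeq2SeqLM, ORTQuantizer
>>> from optimum.onnxruntime.configuration import AutoQuantizationConfig

# Export a Seq2Seq model to ONNX (one file per component), then quantize each file separately
>>> model = ORTModelForSeq2SeqLM.from_pretrained("sshleifer/distilbart-cnn-12-6", from_transformers=True)
>>> model.save_pretrained("distilbart_onnx")  # hypothetical local directory

# Dynamic quantization configuration (assumed preset)
>>> dqconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False)

# One quantizer per ONNX component file (file names are assumptions)
>>> for file_name in ["encoder_model.onnx", "decoder_model.onnx", "decoder_with_past_model.onnx"]:
...     quantizer = ORTQuantizer.from_pretrained("distilbart_onnx", file_name=file_name)
...     quantizer.quantize(save_dir="distilbart_onnx_quantized", quantization_config=dqconfig)
```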

38 changes: 14 additions & 24 deletions examples/onnxruntime/optimization/multiple-choice/run_swag.py
@@ -32,13 +32,14 @@
import torch
import transformers
from datasets import load_dataset
from transformers import HfArgumentParser, TrainingArguments
from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version

from optimum.onnxruntime import ORTModel, ORTOptimizer
from optimum.onnxruntime import ORTModelForMultipleChoice, ORTOptimizer
from optimum.onnxruntime.configuration import OptimizationConfig, ORTConfig
from optimum.onnxruntime.model import ORTModel


# Will error if the minimal version of Transformers is not installed. The version of transformers must be >= 4.19.0
@@ -132,10 +133,6 @@ class OptimizationArguments:
Arguments pertaining to what type of optimization we are going to apply on the model.
"""

opset: Optional[int] = field(
default=None,
metadata={"help": "ONNX opset version to export the model with."},
)
optimization_level: Optional[int] = field(
default=1,
metadata={
@@ -224,7 +221,9 @@ def main():

os.makedirs(training_args.output_dir, exist_ok=True)
model_path = os.path.join(training_args.output_dir, "model.onnx")
optimized_model_path = os.path.join(training_args.output_dir, "model-optimized.onnx")
optimized_model_path = os.path.join(training_args.output_dir, "model_optimized.onnx")

tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name or model_args.model_name_or_path)

# Create the optimization configuration containing all the optimization parameters
optimization_config = OptimizationConfig(
@@ -233,22 +232,14 @@
optimize_for_gpu=optim_args.optimize_for_gpu,
)

# Create the optimizer
optimizer = ORTOptimizer.from_pretrained(
model_args.model_name_or_path, feature="multiple-choice", opset=optim_args.opset
)
# Export the model
model = ORTModelForMultipleChoice.from_pretrained(model_args.model_name_or_path, from_transformers=True)

# Export the optimized model
optimizer.export(
onnx_model_path=model_path,
onnx_optimized_model_output_path=optimized_model_path,
optimization_config=optimization_config,
)
# Create the optimizer
optimizer = ORTOptimizer.from_pretrained(model)

# Create the ONNX Runtime configuration summarizing all the parameters related to ONNX IR export and optimization
ort_config = ORTConfig(opset=optimizer.opset, optimization=optimization_config)
# Save the configuration
ort_config.save_pretrained(training_args.output_dir)
# Optimize the model
optimizer.optimize(optimization_config=optimization_config, save_dir=training_args.output_dir)

if training_args.do_eval:
# Prepare the dataset downloading, preprocessing and metric creation to perform the evaluation and / or the
@@ -313,7 +304,7 @@ def preprocess_function(examples, tokenizer: PreTrainedTokenizerBase):
# Preprocess the evaluation dataset
with training_args.main_process_first(desc="Running tokenizer on the validation dataset"):
eval_dataset = eval_dataset.map(
partial(preprocess_function, tokenizer=optimizer.preprocessor),
partial(preprocess_function, tokenizer=tokenizer),
batched=True,
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=not data_args.overwrite_cache,
@@ -330,15 +321,14 @@ def compute_metrics(eval_predictions):

ort_model = ORTModel(
optimized_model_path,
optimizer._onnx_config,
execution_provider=optim_args.execution_provider,
compute_metrics=compute_metrics,
label_names=["label"],
)
outputs = ort_model.evaluation_loop(eval_dataset)

# Save evaluation metrics
with open(os.path.join(training_args.output_dir, f"eval_results.json"), "w") as f:
with open(os.path.join(training_args.output_dir, "eval_results.json"), "w") as f:
json.dump(outputs.metrics, f, indent=4, sort_keys=True)


56 changes: 24 additions & 32 deletions examples/onnxruntime/optimization/question-answering/run_qa.py
@@ -30,13 +30,13 @@
import datasets
import transformers
from datasets import load_dataset, load_metric
from transformers import EvalPrediction, HfArgumentParser, PreTrainedTokenizer, TrainingArguments
from transformers import AutoTokenizer, EvalPrediction, HfArgumentParser, PreTrainedTokenizer, TrainingArguments
from transformers.utils import check_min_version
from transformers.utils.versions import require_version

from optimum.onnxruntime import ORTModelForQuestionAnswering, ORTOptimizer
from optimum.onnxruntime.configuration import OptimizationConfig, ORTConfig
from optimum.onnxruntime.model import ORTModel
from optimum.onnxruntime.optimization import ORTOptimizer
from trainer_qa import QuestionAnsweringTrainer
from utils_qa import postprocess_qa_predictions

@@ -81,10 +81,6 @@ class ModelArguments:
"with private models)."
},
)
execution_provider: str = field(
default="CPUExecutionProvider",
metadata={"help": "ONNX Runtime execution provider to use for inference."},
)


@dataclass
@@ -204,10 +200,6 @@ class OptimizationArguments:
Arguments pertaining to what type of optimization we are going to apply on the model.
"""

opset: Optional[int] = field(
default=None,
metadata={"help": "ONNX opset version to export the model with."},
)
optimization_level: Optional[int] = field(
default=1,
metadata={
@@ -233,6 +225,10 @@ class OptimizationArguments:
"GPU or CPU only when optimization_level > 1."
},
)
execution_provider: str = field(
default="CPUExecutionProvider",
metadata={"help": "ONNX Runtime execution provider to use for inference."},
)


def main():
@@ -264,7 +260,7 @@ def main():
if (
optim_args.optimization_level > 1
and optim_args.optimize_for_gpu
and model_args.execution_provider == "CPUExecutionProvider"
and optim_args.execution_provider == "CPUExecutionProvider"
):
raise ValueError(
f"Optimization level is set at {optim_args.optimization_level} and "
@@ -275,7 +271,7 @@
if (
optim_args.optimization_level > 1
and not optim_args.optimize_for_gpu
and model_args.execution_provider == "CUDAExecutionProvider"
and optim_args.execution_provider == "CUDAExecutionProvider"
):
raise ValueError(
f"Optimization level is set at {optim_args.optimization_level} and "
@@ -293,7 +289,9 @@

os.makedirs(training_args.output_dir, exist_ok=True)
model_path = os.path.join(training_args.output_dir, "model.onnx")
optimized_model_path = os.path.join(training_args.output_dir, "model-optimized.onnx")
optimized_model_path = os.path.join(training_args.output_dir, "model_optimized.onnx")

tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name or model_args.model_name_or_path)

# Create the optimization configuration containing all the optimization parameters
optimization_config = OptimizationConfig(
@@ -302,22 +300,14 @@
optimize_for_gpu=optim_args.optimize_for_gpu,
)

# Create the optimizer
optimizer = ORTOptimizer.from_pretrained(
model_args.model_name_or_path, feature="question-answering", opset=optim_args.opset
)
# Export the model
model = ORTModelForQuestionAnswering.from_pretrained(model_args.model_name_or_path, from_transformers=True)

# Export the optimized model
optimizer.export(
onnx_model_path=model_path,
onnx_optimized_model_output_path=optimized_model_path,
optimization_config=optimization_config,
)
# Create the optimizer
optimizer = ORTOptimizer.from_pretrained(model)

# Create the ONNX Runtime configuration summarizing all the parameters related to ONNX IR export and optimization
ort_config = ORTConfig(opset=optimizer.opset, optimization=optimization_config)
# Save the configuration
ort_config.save_pretrained(training_args.output_dir)
# Optimize the model
optimizer.optimize(optimization_config=optimization_config, save_dir=training_args.output_dir)

# Prepare the dataset downloading, preprocessing and metric creation to perform the evaluation and / or the
# prediction step(s)
@@ -456,7 +446,7 @@ def compute_metrics(p: EvalPrediction):
if data_args.max_eval_samples is not None:
eval_examples = eval_examples.select(range(data_args.max_eval_samples))
eval_dataset = eval_examples.map(
partial(prepare_validation_features, tokenizer=optimizer.preprocessor),
partial(prepare_validation_features, tokenizer=tokenizer),
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
@@ -469,9 +459,9 @@ def compute_metrics(p: EvalPrediction):

ort_model = ORTModel(
optimized_model_path,
optimizer._onnx_config,
execution_provider=model_args.execution_provider,
execution_provider=optim_args.execution_provider,
compute_metrics=compute_metrics,
label_names=["start_positions", "end_positions"],
)
outputs = ort_model.evaluation_loop(eval_dataset)
predictions = post_processing_function(eval_examples, eval_dataset, outputs.predictions)
@@ -492,7 +482,7 @@ def compute_metrics(p: EvalPrediction):
if data_args.max_predict_samples is not None:
predict_examples = predict_examples.select(range(data_args.max_predict_samples))
predict_dataset = predict_examples.map(
partial(prepare_validation_features, tokenizer=optimizer.preprocessor),
partial(prepare_validation_features, tokenizer=tokenizer),
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
@@ -504,7 +494,9 @@ def compute_metrics(p: EvalPrediction):
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))

ort_model = ORTModel(
optimized_model_path, optimizer._onnx_config, execution_provider=model_args.execution_provider
optimized_model_path,
execution_provider=optim_args.execution_provider,
label_names=["start_positions", "end_positions"],
)
outputs = ort_model.evaluation_loop(predict_dataset)
predictions = post_processing_function(predict_examples, predict_dataset, outputs.predictions)