
ORT optimizer refactorization #294

Merged: 28 commits, Aug 24, 2022

Commits (28)
bef4e10
Refactorization of ORTOptimizer
echarlaix Jul 13, 2022
bf7f38f
Refactorization of ORTModel
echarlaix Jul 13, 2022
16da1cb
Adapt examples according to refactorization
echarlaix Jul 13, 2022
bb8e100
Adapt tests
echarlaix Jul 13, 2022
f744d18
Fix style
echarlaix Jul 13, 2022
3739019
Remove quantizer modification
echarlaix Jul 13, 2022
af56854
Fix style
echarlaix Jul 13, 2022
a2eb07c
Merge branch 'main' into ort-optimizer-refactorization
echarlaix Jul 27, 2022
8f1871f
Merge branch main into feature branch
echarlaix Aug 22, 2022
93f56ff
Apply modifications from #270 for quantizer and optimizer to have sam…
echarlaix Aug 22, 2022
d4873f0
Add test for optimization of Seq2Seq models
echarlaix Aug 22, 2022
e74f425
Fix style
echarlaix Aug 22, 2022
8974aaa
Add ort config saving when optimizing a model
echarlaix Aug 22, 2022
aef3aff
Add ort config saving when quantizing a model
echarlaix Aug 22, 2022
85994e4
Add tests
echarlaix Aug 22, 2022
58b0352
Fix style
echarlaix Aug 22, 2022
dcfbcbc
Adapt optimization examples
echarlaix Aug 23, 2022
864680d
Fix readme
echarlaix Aug 23, 2022
720c05d
Remove unused parameter
echarlaix Aug 23, 2022
d6720b5
Change quantization approach to dynamic in readmes
echarlaix Aug 23, 2022
c4013fd
Adapt quantization examples
echarlaix Aug 23, 2022
d040ffc
Fix quantized model and ort config saving
echarlaix Aug 23, 2022
c1988fb
Add documentation
echarlaix Aug 23, 2022
51d0a96
Add model configuration saving to simplify loading of optimized model
echarlaix Aug 23, 2022
876ab55
Fix style
echarlaix Aug 23, 2022
a949947
Fix readmes description
echarlaix Aug 23, 2022
ec3638a
Fix quantization tests
echarlaix Aug 23, 2022
9ab642b
Remove opset argument which is onnx config default opset when exporti…
echarlaix Aug 24, 2022
95 changes: 95 additions & 0 deletions docs/source/onnxruntime/optimization.mdx
@@ -14,6 +14,101 @@ specific language governing permissions and limitations under the License.

🤗 Optimum provides an `optimum.onnxruntime` package that enables you to apply graph optimization to many models hosted on the 🤗 hub using the [ONNX Runtime](https://github.com/microsoft/onnxruntime/tree/master/onnxruntime/python/tools/transformers) model optimization tool.


## Creating an `ORTOptimizer`

The `ORTOptimizer` class is used to optimize your ONNX model. The class can be initialized using the `from_pretrained()` method, which supports different checkpoint formats.

1. Using an already initialized `ORTModelForXXX` class.

```python
>>> from optimum.onnxruntime import ORTOptimizer, ORTModelForSequenceClassification

# Load an ONNX model from the Hub
>>> model = ORTModelForSequenceClassification.from_pretrained("optimum/distilbert-base-uncased-finetuned-sst-2-english")

# Create an optimizer from an ORTModelForXXX
>>> optimizer = ORTOptimizer.from_pretrained(model)
```

2. Using a local ONNX model from a directory.

```python
>>> from optimum.onnxruntime import ORTOptimizer

# This assumes a model.onnx exists in path/to/model
>>> optimizer = ORTOptimizer.from_pretrained("path/to/model")
```
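
If the checkpoint uses a non-default file name, you can point the optimizer at it explicitly. A minimal sketch, assuming a `file_names` argument on `from_pretrained` (verify against your installed version):

```python
>>> from optimum.onnxruntime import ORTOptimizer

# This assumes path/to/model contains a file named my_model.onnx (hypothetical name)
>>> optimizer = ORTOptimizer.from_pretrained("path/to/model", file_names=["my_model.onnx"])
```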


## Optimization examples

Below is an end-to-end example showing how to optimize [distilbert-base-uncased-finetuned-sst-2-english](https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english).

```python
>>> from optimum.onnxruntime import ORTOptimizer, ORTModelForSequenceClassification
>>> from optimum.onnxruntime.configuration import OptimizationConfig

>>> model_id = "distilbert-base-uncased-finetuned-sst-2-english"
>>> save_dir = "/tmp/outputs"

# Load a PyTorch model and export it to the ONNX format
>>> model = ORTModelForSequenceClassification.from_pretrained(model_id, from_transformers=True)

# Create the optimizer
>>> optimizer = ORTOptimizer.from_pretrained(model)

# Define the optimization strategy by creating the appropriate configuration
>>> optimization_config = OptimizationConfig(
...     optimization_level=2,
...     optimize_with_onnxruntime_only=False,
...     optimize_for_gpu=False,
... )

# Optimize the model
>>> optimizer.optimize(save_dir=save_dir, optimization_config=optimization_config)
```
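
The optimized model and its ORT configuration are saved to `save_dir`. Continuing the session above, a minimal sketch of loading the result back for inference, assuming the optimizer's default `model_optimized.onnx` file name (verify against your installed version):

```python
>>> from transformers import AutoTokenizer, pipeline

# File name is an assumption based on the optimizer's default "_optimized" suffix
>>> optimized_model = ORTModelForSequenceClassification.from_pretrained(
...     save_dir, file_name="model_optimized.onnx"
... )
>>> tokenizer = AutoTokenizer.from_pretrained(model_id)
>>> classifier = pipeline("text-classification", model=optimized_model, tokenizer=tokenizer)
>>> classifier("I love the optimized model!")
```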


Below is an end-to-end example showing how to optimize a Seq2Seq model, [sshleifer/distilbart-cnn-12-6](https://huggingface.co/sshleifer/distilbart-cnn-12-6).

```python
>>> from optimum.onnxruntime import ORTOptimizer, ORTModelForSeq2SeqLM
>>> from optimum.onnxruntime.configuration import OptimizationConfig
>>> from transformers import AutoTokenizer

>>> model_id = "sshleifer/distilbart-cnn-12-6"
>>> save_dir = "/tmp/outputs"

# Load a PyTorch model and export it to the ONNX format
>>> model = ORTModelForSeq2SeqLM.from_pretrained(model_id, from_transformers=True)

# Create the optimizer
>>> optimizer = ORTOptimizer.from_pretrained(model)

# Define the optimization strategy by creating the appropriate configuration
>>> optimization_config = OptimizationConfig(
...     optimization_level=2,
...     optimize_with_onnxruntime_only=False,
...     optimize_for_gpu=False,
... )

# Optimize the model
>>> optimizer.optimize(save_dir=save_dir, optimization_config=optimization_config)

# Load the resulting optimized model
>>> optimized_model = ORTModelForSeq2SeqLM.from_pretrained(
...     save_dir,
...     encoder_file_name="encoder_model_optimized.onnx",
...     decoder_file_name="decoder_model_optimized.onnx",
...     decoder_file_with_past_name="decoder_with_past_model_optimized.onnx",
... )
>>> tokenizer = AutoTokenizer.from_pretrained(model_id)
>>> tokens = tokenizer("This is a sample input", return_tensors="pt")
>>> outputs = optimized_model.generate(**tokens)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```


## ORTOptimizer

[[autodoc]] onnxruntime.optimization.ORTOptimizer
2 changes: 1 addition & 1 deletion docs/source/onnxruntime/quantization.mdx
@@ -124,7 +124,7 @@ The `ORTQuantizer` class can be used to statically quantize your ONNX model. Bel
)
```

## Quantize Seq2Seq models.
## Quantize Seq2Seq models

The `ORTQuantizer` currently doesn't support multi-file models, such as `ORTModelForSeq2SeqLM`. If you want to quantize a Seq2Seq model, you have to quantize each of the model's components individually using the `ORTQuantizer` class. Currently, only dynamic quantization is supported for Seq2Seq models.
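
A minimal sketch of this per-component flow, assuming the quantizer accepts a `file_name` argument and that the export produces the default component file names (verify both against your installed version):

```python
>>> from optimum.onnxruntime import ORTModelForSeq2SeqLM, ORTQuantizer
>>> from optimum.onnxruntime.configuration import AutoQuantizationConfig

>>> model_id = "sshleifer/distilbart-cnn-12-6"
>>> save_dir = "/tmp/quantized"

# Export the model and save its components
>>> model = ORTModelForSeq2SeqLM.from_pretrained(model_id, from_transformers=True)
>>> model.save_pretrained(save_dir)

# Dynamic quantization configuration
>>> dqconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False)

# Quantize each component of the model individually
>>> for file_name in ("encoder_model.onnx", "decoder_model.onnx", "decoder_with_past_model.onnx"):
...     quantizer = ORTQuantizer.from_pretrained(save_dir, file_name=file_name)
...     quantizer.quantize(save_dir=save_dir, quantization_config=dqconfig)
```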

38 changes: 14 additions & 24 deletions examples/onnxruntime/optimization/multiple-choice/run_swag.py
@@ -32,13 +32,14 @@
import torch
import transformers
from datasets import load_dataset
from transformers import HfArgumentParser, TrainingArguments
from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version

from optimum.onnxruntime import ORTModel, ORTOptimizer
from optimum.onnxruntime import ORTModelForMultipleChoice, ORTOptimizer
from optimum.onnxruntime.configuration import OptimizationConfig, ORTConfig
from optimum.onnxruntime.model import ORTModel


# Will error if the minimal version of Transformers is not installed. The version of transformers must be >= 4.19.0
@@ -132,10 +133,6 @@ class OptimizationArguments:
Arguments pertaining to what type of optimization we are going to apply on the model.
"""

opset: Optional[int] = field(
default=None,
metadata={"help": "ONNX opset version to export the model with."},
)
optimization_level: Optional[int] = field(
default=1,
metadata={
@@ -224,7 +221,9 @@ def main():

os.makedirs(training_args.output_dir, exist_ok=True)
model_path = os.path.join(training_args.output_dir, "model.onnx")
optimized_model_path = os.path.join(training_args.output_dir, "model-optimized.onnx")
optimized_model_path = os.path.join(training_args.output_dir, "model_optimized.onnx")

tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name or model_args.model_name_or_path)

# Create the optimization configuration containing all the optimization parameters
optimization_config = OptimizationConfig(
@@ -233,22 +232,14 @@
optimize_for_gpu=optim_args.optimize_for_gpu,
)

# Create the optimizer
optimizer = ORTOptimizer.from_pretrained(
model_args.model_name_or_path, feature="multiple-choice", opset=optim_args.opset
)
# Export the model
model = ORTModelForMultipleChoice.from_pretrained(model_args.model_name_or_path, from_transformers=True)

# Export the optimized model
optimizer.export(
onnx_model_path=model_path,
onnx_optimized_model_output_path=optimized_model_path,
optimization_config=optimization_config,
)
# Create the optimizer
optimizer = ORTOptimizer.from_pretrained(model)

# Create the ONNX Runtime configuration summarizing all the parameters related to ONNX IR export and optimization
ort_config = ORTConfig(opset=optimizer.opset, optimization=optimization_config)
# Save the configuration
ort_config.save_pretrained(training_args.output_dir)
# Optimize the model
optimizer.optimize(optimization_config=optimization_config, save_dir=training_args.output_dir)

if training_args.do_eval:
# Prepare the dataset downloading, preprocessing and metric creation to perform the evaluation and / or the
@@ -313,7 +304,7 @@ def preprocess_function(examples, tokenizer: PreTrainedTokenizerBase):
# Preprocess the evaluation dataset
with training_args.main_process_first(desc="Running tokenizer on the validation dataset"):
eval_dataset = eval_dataset.map(
partial(preprocess_function, tokenizer=optimizer.preprocessor),
partial(preprocess_function, tokenizer=tokenizer),
batched=True,
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=not data_args.overwrite_cache,
@@ -330,15 +321,14 @@ def compute_metrics(eval_predictions):

ort_model = ORTModel(
optimized_model_path,
optimizer._onnx_config,
execution_provider=optim_args.execution_provider,
compute_metrics=compute_metrics,
label_names=["label"],
)
outputs = ort_model.evaluation_loop(eval_dataset)

# Save evaluation metrics
with open(os.path.join(training_args.output_dir, f"eval_results.json"), "w") as f:
with open(os.path.join(training_args.output_dir, "eval_results.json"), "w") as f:
json.dump(outputs.metrics, f, indent=4, sort_keys=True)


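
For reference, the refactored flow this diff introduces, distilled into a minimal standalone sketch; the checkpoint name and output directory below are placeholders:

```python
# Minimal sketch of the refactored run_swag.py flow; names are placeholders.
from optimum.onnxruntime import ORTModelForMultipleChoice, ORTOptimizer
from optimum.onnxruntime.configuration import OptimizationConfig

output_dir = "/tmp/swag_optimized"

# Export the PyTorch checkpoint to the ONNX format
model = ORTModelForMultipleChoice.from_pretrained("bert-base-uncased", from_transformers=True)

# Create the optimizer and apply basic graph optimizations
optimizer = ORTOptimizer.from_pretrained(model)
optimizer.optimize(
    optimization_config=OptimizationConfig(optimization_level=1),
    save_dir=output_dir,
)
```
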
56 changes: 24 additions & 32 deletions examples/onnxruntime/optimization/question-answering/run_qa.py
@@ -30,13 +30,13 @@
import datasets
import transformers
from datasets import load_dataset, load_metric
from transformers import EvalPrediction, HfArgumentParser, PreTrainedTokenizer, TrainingArguments
from transformers import AutoTokenizer, EvalPrediction, HfArgumentParser, PreTrainedTokenizer, TrainingArguments
from transformers.utils import check_min_version
from transformers.utils.versions import require_version

from optimum.onnxruntime import ORTModelForQuestionAnswering, ORTOptimizer
from optimum.onnxruntime.configuration import OptimizationConfig, ORTConfig
from optimum.onnxruntime.model import ORTModel
from optimum.onnxruntime.optimization import ORTOptimizer
from trainer_qa import QuestionAnsweringTrainer
from utils_qa import postprocess_qa_predictions

@@ -81,10 +81,6 @@ class ModelArguments:
"with private models)."
},
)
execution_provider: str = field(
default="CPUExecutionProvider",
metadata={"help": "ONNX Runtime execution provider to use for inference."},
)


@dataclass
@@ -204,10 +200,6 @@ class OptimizationArguments:
Arguments pertaining to what type of optimization we are going to apply on the model.
"""

opset: Optional[int] = field(
default=None,
metadata={"help": "ONNX opset version to export the model with."},
)
optimization_level: Optional[int] = field(
default=1,
metadata={
@@ -233,6 +225,10 @@ class OptimizationArguments:
"GPU or CPU only when optimization_level > 1."
},
)
execution_provider: str = field(
default="CPUExecutionProvider",
metadata={"help": "ONNX Runtime execution provider to use for inference."},
)


def main():
Expand Down Expand Up @@ -264,7 +260,7 @@ def main():
if (
optim_args.optimization_level > 1
and optim_args.optimize_for_gpu
and model_args.execution_provider == "CPUExecutionProvider"
and optim_args.execution_provider == "CPUExecutionProvider"
):
raise ValueError(
f"Optimization level is set at {optim_args.optimization_level} and "
@@ -275,7 +271,7 @@ def main():
if (
optim_args.optimization_level > 1
and not optim_args.optimize_for_gpu
and model_args.execution_provider == "CUDAExecutionProvider"
and optim_args.execution_provider == "CUDAExecutionProvider"
):
raise ValueError(
f"Optimization level is set at {optim_args.optimization_level} and "
@@ -293,7 +289,9 @@

os.makedirs(training_args.output_dir, exist_ok=True)
model_path = os.path.join(training_args.output_dir, "model.onnx")
optimized_model_path = os.path.join(training_args.output_dir, "model-optimized.onnx")
optimized_model_path = os.path.join(training_args.output_dir, "model_optimized.onnx")

tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name or model_args.model_name_or_path)

# Create the optimization configuration containing all the optimization parameters
optimization_config = OptimizationConfig(
@@ -302,22 +300,14 @@
optimize_for_gpu=optim_args.optimize_for_gpu,
)

# Create the optimizer
optimizer = ORTOptimizer.from_pretrained(
model_args.model_name_or_path, feature="question-answering", opset=optim_args.opset
)
# Export the model
model = ORTModelForQuestionAnswering.from_pretrained(model_args.model_name_or_path, from_transformers=True)

# Export the optimized model
optimizer.export(
onnx_model_path=model_path,
onnx_optimized_model_output_path=optimized_model_path,
optimization_config=optimization_config,
)
# Create the optimizer
optimizer = ORTOptimizer.from_pretrained(model)

# Create the ONNX Runtime configuration summarizing all the parameters related to ONNX IR export and optimization
ort_config = ORTConfig(opset=optimizer.opset, optimization=optimization_config)
# Save the configuration
ort_config.save_pretrained(training_args.output_dir)
# Optimize the model
optimizer.optimize(optimization_config=optimization_config, save_dir=training_args.output_dir)

# Prepare the dataset downloading, preprocessing and metric creation to perform the evaluation and / or the
# prediction step(s)
@@ -456,7 +446,7 @@ def compute_metrics(p: EvalPrediction):
if data_args.max_eval_samples is not None:
eval_examples = eval_examples.select(range(data_args.max_eval_samples))
eval_dataset = eval_examples.map(
partial(prepare_validation_features, tokenizer=optimizer.preprocessor),
partial(prepare_validation_features, tokenizer=tokenizer),
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
@@ -469,9 +459,9 @@ def compute_metrics(p: EvalPrediction):

ort_model = ORTModel(
optimized_model_path,
optimizer._onnx_config,
execution_provider=model_args.execution_provider,
execution_provider=optim_args.execution_provider,
compute_metrics=compute_metrics,
label_names=["start_positions", "end_positions"],
)
outputs = ort_model.evaluation_loop(eval_dataset)
predictions = post_processing_function(eval_examples, eval_dataset, outputs.predictions)
@@ -492,7 +482,7 @@ def compute_metrics(p: EvalPrediction):
if data_args.max_predict_samples is not None:
predict_examples = predict_examples.select(range(data_args.max_predict_samples))
predict_dataset = predict_examples.map(
partial(prepare_validation_features, tokenizer=optimizer.preprocessor),
partial(prepare_validation_features, tokenizer=tokenizer),
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
Expand All @@ -504,7 +494,9 @@ def compute_metrics(p: EvalPrediction):
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))

ort_model = ORTModel(
optimized_model_path, optimizer._onnx_config, execution_provider=model_args.execution_provider
optimized_model_path,
execution_provider=optim_args.execution_provider,
label_names=["start_positions", "end_positions"],
)
outputs = ort_model.evaluation_loop(predict_dataset)
predictions = post_processing_function(predict_examples, predict_dataset, outputs.predictions)
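
Both example scripts end with the same evaluation pattern; a minimal sketch, assuming `eval_dataset` and `compute_metrics` are prepared as in the scripts above:

```python
# Minimal sketch of the evaluation pattern used in these examples.
# eval_dataset and compute_metrics are assumed to be defined as in the scripts.
from optimum.onnxruntime.model import ORTModel

ort_model = ORTModel(
    "/tmp/outputs/model_optimized.onnx",  # path produced by optimizer.optimize()
    execution_provider="CPUExecutionProvider",
    compute_metrics=compute_metrics,
    label_names=["start_positions", "end_positions"],
)
outputs = ort_model.evaluation_loop(eval_dataset)
print(outputs.metrics)
```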