Add optimize and quantize command CLI #700

Merged
merged 16 commits into from
Feb 2, 2023
36 changes: 35 additions & 1 deletion docs/source/onnxruntime/usage_guides/optimization.mdx
@@ -15,7 +15,41 @@ specific language governing permissions and limitations under the License.
🤗 Optimum provides an `optimum.onnxruntime` package that enables you to apply graph optimization on many models hosted on the 🤗 hub using the [ONNX Runtime](https://github.com/microsoft/onnxruntime/tree/master/onnxruntime/python/tools/transformers) model optimization tool.


## Creating an `ORTOptimizer`
## Optimizing a model with Optimum's CLI

The Optimum ONNX Runtime optimization tool can be used through the Optimum command-line interface:

```bash
optimum-cli onnxruntime optimize --help
usage: optimum-cli <command> [<args>] onnxruntime optimize [-h] --onnx_model ONNX_MODEL [-o OUTPUT] (-O1 | -O2 | -O3 | -O4)

options:
-h, --help show this help message and exit
-O1 Basic general optimizations (see: https://huggingface.co/docs/optimum/onnxruntime/usage_guides/optimization for more details).
-O2 Basic and extended general optimizations, transformers-specific fusions (see: https://huggingface.co/docs/optimum/onnxruntime/usage_guides/optimization for more details).
-O3 Same as O2 with Gelu approximation (see: https://huggingface.co/docs/optimum/onnxruntime/usage_guides/optimization for more details).
-O4 Same as O3 with mixed precision (see: https://huggingface.co/docs/optimum/onnxruntime/usage_guides/optimization for more details).
-c, --config `ORTConfig` file to use to optimize the model.

Review comment (Member): Add --config details here

Required arguments:
--onnx_model ONNX_MODEL
Path indicating where the ONNX models to optimize are located.

Optional arguments:
-o OUTPUT, --output OUTPUT
Path to the directory where the generated ONNX models will be stored (defaults to the --onnx_model value).
```

Optimizing an ONNX model can be done as follows:

```bash
optimum-cli onnxruntime optimize --onnx_model onnx_model_location/ -O1
```

This optimizes all the ONNX files in `onnx_model_location` with the basic general optimizations. The optimized models will be created in the same directory by default unless the `--output` argument is specified.
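
The `-c`/`--config` option takes a saved `ORTConfig` instead of one of the predefined `-O` levels. As a rough sketch of how such a configuration could be produced programmatically (the directory name `ort_config_dir/` is purely illustrative, and `save_pretrained` is assumed to be available on `ORTConfig` through Optimum's base configuration class):

```python
from optimum.onnxruntime.configuration import AutoOptimizationConfig, ORTConfig

# Wrap a predefined optimization level (O2 here) in an ORTConfig so that it can
# be saved to disk and later passed to `optimum-cli onnxruntime optimize -c <path>`.
optimization_config = AutoOptimizationConfig.O2()
ort_config = ORTConfig(optimization=optimization_config)

# Illustrative output location; check the exact save API against your Optimum version.
ort_config.save_pretrained("ort_config_dir")
```

The resulting path can then be passed with `-c`/`--config` in place of `-O1`/`-O2`/`-O3`/`-O4`.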


## Optimizing a model with Optimum's `ORTOptimizer`

The [`~onnxruntime.ORTOptimizer`] class is used to optimize your ONNX model. The class can be initialized using the `from_pretrained()` method, which supports different checkpoint formats.
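
For reference, the CLI command above wraps this same programmatic API. A minimal sketch of the equivalent `ORTOptimizer` call (the directory and the `model.onnx` file name are placeholders for your own export):

```python
from optimum.onnxruntime import ORTOptimizer
from optimum.onnxruntime.configuration import AutoOptimizationConfig

# Load the exported ONNX file(s) from the model directory, as the CLI does internally.
optimizer = ORTOptimizer.from_pretrained("onnx_model_location", file_names=["model.onnx"])

# Basic general optimizations, the programmatic counterpart of -O1.
optimization_config = AutoOptimizationConfig.O1()

optimizer.optimize(save_dir="onnx_model_location", optimization_config=optimization_config)
```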

34 changes: 34 additions & 0 deletions docs/source/onnxruntime/usage_guides/quantization.mdx
@@ -28,6 +28,40 @@ explains the main concepts that you will be using when performing quantization w

</Tip>

## Quantizing a model with Optimum's CLI

The Optimum ONNX Runtime quantization tool can be used through the Optimum command-line interface:

```bash
optimum-cli onnxruntime quantize --help
usage: optimum-cli <command> [<args>] onnxruntime quantize [-h] --onnx_model ONNX_MODEL [-o OUTPUT] (--arm64 | --avx2 | --avx512 | --avx512_vnni | --tensorrt)

options:
-h, --help show this help message and exit
--arm64 Quantization for the ARM64 architecture.
--avx2 Quantization with AVX-2 instructions.
--avx512 Quantization with AVX-512 instructions.
--avx512_vnni Quantization with AVX-512 and VNNI instructions.
--tensorrt Quantization for NVIDIA TensorRT optimizer.

Review comment (Member): Add --config details here

-c, --config `ORTConfig` file to use to quantize the model.

Required arguments:
--onnx_model ONNX_MODEL
Path indicating where the ONNX models to quantize are located.

Optional arguments:
-o OUTPUT, --output OUTPUT
Path to the directory where the generated ONNX models will be stored (defaults to the --onnx_model value).
```

Quantizing an ONNX model can be done as follows:

```bash
optimum-cli onnxruntime quantize --onnx_model onnx_model_location/ --avx512
```

This quantizes all the ONNX files in `onnx_model_location` with the AVX-512 instructions. The quantized models will be created in the same directory by default unless the `--output` argument is specified.
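
For reference, this command wraps the same programmatic API that the next section introduces. A minimal sketch of the equivalent `ORTQuantizer` call (the directory and the `model.onnx` file name are placeholders for your own export):

```python
from optimum.onnxruntime import ORTQuantizer
from optimum.onnxruntime.configuration import AutoQuantizationConfig

# One quantizer per ONNX file, as the CLI builds internally.
quantizer = ORTQuantizer.from_pretrained("onnx_model_location", file_name="model.onnx")

# Dynamic quantization targeting AVX-512, the programmatic counterpart of --avx512.
qconfig = AutoQuantizationConfig.avx512(is_static=False, per_channel=False)

quantizer.quantize(save_dir="onnx_model_location", quantization_config=qconfig)
```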

## Creating an `ORTQuantizer`

The [`~optimum.onnxruntime.ORTQuantizer`] class is used to quantize your ONNX model. The class can be initialized using
33 changes: 33 additions & 0 deletions optimum/commands/onnxruntime/__init__.py
@@ -0,0 +1,33 @@
import sys
from argparse import ArgumentParser

from .. import BaseOptimumCLICommand
from .optimize import ONNXRuntimmeOptimizeCommand, parse_args_onnxruntime_optimize
from .quantize import ONNXRuntimmeQuantizeCommand, parse_args_onnxruntime_quantize


def onnxruntime_optimize_factory(args):
return ONNXRuntimmeOptimizeCommand(args)


def onnxruntime_quantize_factory(args):
return ONNXRuntimmeQuantizeCommand(args)


class ONNXRuntimeCommand(BaseOptimumCLICommand):
@staticmethod
def register_subcommand(parser: ArgumentParser):
onnxruntime_parser = parser.add_parser("onnxruntime", help="ONNX Runtime optimize and quantize utilities.")
onnxruntime_sub_parsers = onnxruntime_parser.add_subparsers()

optimize_parser = onnxruntime_sub_parsers.add_parser("optimize", help="Optimize ONNX models.")
quantize_parser = onnxruntime_sub_parsers.add_parser("quantize", help="Dynamic quantization for ONNX models.")

parse_args_onnxruntime_optimize(optimize_parser)
parse_args_onnxruntime_quantize(quantize_parser)

optimize_parser.set_defaults(func=onnxruntime_optimize_factory)
quantize_parser.set_defaults(func=onnxruntime_quantize_factory)

def run(self):
raise NotImplementedError()
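
For context, a rough sketch of how a factory registered this way is typically dispatched by the top-level parser (an illustration of the usual argparse pattern, not the exact body of `optimum_cli.py`):

```python
from argparse import ArgumentParser

def dispatch(parser: ArgumentParser) -> None:
    # Each subcommand stores a factory in `func` via set_defaults; the entry
    # point instantiates the command object and runs it.
    args = parser.parse_args()
    if hasattr(args, "func"):
        command = args.func(args)  # e.g. ONNXRuntimmeOptimizeCommand(args)
        command.run()
    else:
        parser.print_help()
```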
80 changes: 80 additions & 0 deletions optimum/commands/onnxruntime/optimize.py
@@ -0,0 +1,80 @@
from pathlib import Path

from ...onnxruntime.configuration import AutoOptimizationConfig, ORTConfig
from ...onnxruntime.optimization import ORTOptimizer


def parse_args_onnxruntime_optimize(parser):
required_group = parser.add_argument_group("Required arguments")
required_group.add_argument(
"--onnx_model",

Review comment (Member): Why not make that just a regular argument:

Suggested change: "--onnx_model" -> "onnx_model"

That way we can do:

optimum-cli onnxruntime optimize path_to_my_model -O2 my_output

Which seems less heavy IMO.

@fxmarty and maybe we should do the same for exporters, wdyt?

Reply (Contributor, author): Yes, we can do that for the onnx_model argument.

Reply (Contributor, fxmarty, Jan 19, 2023): I think having several unnamed arguments makes things less readable and error-prone. I'm not in favor personally, but it's a matter of taste.

Reply (Member): Fair enough, let's keep it like that then!

type=Path,
required=True,
help="Path to the repository where the ONNX models to optimize are located.",
)

optional_group = parser.add_argument_group("Optional arguments")
optional_group.add_argument(
"-o",
"--output",

Review comment (Member) on lines +18 to +19: Same comment

Reply (Contributor, author): The output is now optional, and it cannot be made a regular argument like for export, because the command uses the onnx_model folder as the default output. Unless this is not the behavior you would like?

Reply (Member): Let's keep it like that then!

type=Path,
help="Path to the directory where to store generated ONNX model. (defaults to --onnx_model value).",
)

Review comment (Member): Being able to provide predefined optimization configs is great. Could we also add the possibility to provide a path where an ORTConfig is stored?

Reply (Contributor, author): I can add this, yes.

Reply (Member, michaelbenayoun, Jan 19, 2023): I think it would allow more custom usage!

Reply (Contributor, author): @michaelbenayoun Do you mean ORTConfig or OptimizationConfig?

Reply (Member): ORTConfig, since those are the ones we push to the Hub.

Reply (Contributor, author): Done!

level_group = parser.add_mutually_exclusive_group(required=True)
level_group.add_argument(
"-O1",
action="store_true",
help="Basic general optimizations (see: https://huggingface.co/docs/optimum/onnxruntime/usage_guides/optimization for more details).",
)
level_group.add_argument(
"-O2",
action="store_true",
help="Basic and extended general optimizations, transformers-specific fusions (see: https://huggingface.co/docs/optimum/onnxruntime/usage_guides/optimization for more details).",
)
level_group.add_argument(
"-O3",
action="store_true",
help="Same as O2 with Gelu approximation (see: https://huggingface.co/docs/optimum/onnxruntime/usage_guides/optimization for more details).",
)
level_group.add_argument(
"-O4",
action="store_true",
help="Same as O3 with mixed precision (see: https://huggingface.co/docs/optimum/onnxruntime/usage_guides/optimization for more details).",
)
level_group.add_argument(
"-c",
"--config",
type=Path,
help="`ORTConfig` file to use to optimize the model.",
)


class ONNXRuntimmeOptimizeCommand:
def __init__(self, args):
self.args = args

def run(self):
if not self.args.output:
save_dir = self.args.onnx_model
else:
save_dir = self.args.output

file_names = [model.name for model in self.args.onnx_model.glob("*.onnx")]

optimizer = ORTOptimizer.from_pretrained(self.args.onnx_model, file_names)

if self.args.O1:
optimization_config = AutoOptimizationConfig.O1()
elif self.args.O2:
optimization_config = AutoOptimizationConfig.O2()
elif self.args.O3:
optimization_config = AutoOptimizationConfig.O3()
elif self.args.O4:
optimization_config = AutoOptimizationConfig.O4()
else:

Review comment (Member): Suggested change: `else:` -> `elif self.args.config:`

Reply (Contributor, author): Same answer as below.

optimization_config = ORTConfig.from_pretrained(self.args.config).optimization

Review comment (Member): Suggested change:

else:
    raise ValueError("An optimization configuration must be provided, either by using the predefined optimization configurations (O1, O2, O3, O4) or by specifying the path to a custom ORTConfig")

Reply (Contributor, author): It is an exclusive group, so at least one option is mandatory; if none is given, an error is raised. This part is therefore not needed.

optimizer.optimize(save_dir=save_dir, optimization_config=optimization_config)
77 changes: 77 additions & 0 deletions optimum/commands/onnxruntime/quantize.py
@@ -0,0 +1,77 @@
from argparse import ArgumentParser
from pathlib import Path

from ...onnxruntime.configuration import AutoQuantizationConfig, ORTConfig
from ...onnxruntime.quantization import ORTQuantizer


def parse_args_onnxruntime_quantize(parser):
required_group = parser.add_argument_group("Required arguments")
required_group.add_argument(
"--onnx_model",

Review comment (Member): Same comment

type=Path,
required=True,
help="Path to the repository where the ONNX models to quantize are located.",
)

optional_group = parser.add_argument_group("Optional arguments")
optional_group.add_argument(
"-o",

Review comment (Member): Same comment

"--output",
type=Path,
help="Path to the directory where to store generated ONNX model. (defaults to --onnx_model value).",
)
optional_group.add_argument(
"--per_channel",
action="store_true",
help="Compute the quantization parameters on a per-channel basis.",
)

level_group = parser.add_mutually_exclusive_group(required=True)

Review comment (Member): Same comment as the optimization level, maybe we could add the possibility to specify a path to an ORTConfig.

level_group.add_argument("--arm64", action="store_true", help="Quantization for the ARM64 architecture.")
level_group.add_argument("--avx2", action="store_true", help="Quantization with AVX-2 instructions.")
level_group.add_argument("--avx512", action="store_true", help="Quantization with AVX-512 instructions.")
level_group.add_argument(
"--avx512_vnni", action="store_true", help="Quantization with AVX-512 and VNNI instructions."
)
level_group.add_argument("--tensorrt", action="store_true", help="Quantization for NVIDIA TensorRT optimizer.")
level_group.add_argument(
"-c",
"--config",
type=Path,
help="`ORTConfig` file to use to optimize the model.",
)


class ONNXRuntimmeQuantizeCommand:
def __init__(self, args):
self.args = args

def run(self):
if not self.args.output:
save_dir = self.args.onnx_model
else:
save_dir = self.args.output

quantizers = [
ORTQuantizer.from_pretrained(self.args.onnx_model, file_name=model.name)
for model in self.args.onnx_model.glob("*.onnx")
]

if self.args.arm64:
qconfig = AutoQuantizationConfig.arm64(is_static=False, per_channel=self.args.per_channel)
elif self.args.avx2:
qconfig = AutoQuantizationConfig.avx2(is_static=False, per_channel=self.args.per_channel)
elif self.args.avx512:
qconfig = AutoQuantizationConfig.avx512(is_static=False, per_channel=self.args.per_channel)
elif self.args.avx512_vnni:
qconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=self.args.per_channel)
elif self.args.tensorrt:
qconfig = AutoQuantizationConfig.tensorrt(is_static=False, per_channel=self.args.per_channel)
else:
qconfig = ORTConfig.from_pretrained(self.args.config).quantization

for q in quantizers:
q.quantize(save_dir=save_dir, quantization_config=qconfig)
2 changes: 2 additions & 0 deletions optimum/commands/optimum_cli.py
@@ -18,6 +18,7 @@

from .env import EnvironmentCommand
from .export import ExportCommand
from .onnxruntime import ONNXRuntimeCommand


def main():
@@ -27,6 +28,7 @@ def main():
# Register commands
ExportCommand.register_subcommand(commands_parser)
EnvironmentCommand.register_subcommand(commands_parser)
ONNXRuntimeCommand.register_subcommand(commands_parser)

args = parser.parse_args()

1 change: 1 addition & 0 deletions setup.py
@@ -20,6 +20,7 @@
"packaging",
"numpy<1.24.0",
"huggingface_hub>=0.8.0",
"datasets",
]

TESTS_REQUIRE = ["pytest", "requests", "parameterized", "pytest-xdist", "Pillow", "sacremoses", "diffusers"]
48 changes: 43 additions & 5 deletions tests/cli/test_cli.py
@@ -25,17 +25,55 @@ def test_helps_no_raise(self):
"optimum-cli export --help",
"optimum-cli export onnx --help",
"optimum-cli env --help",
"optimum-cli onnxruntime quantize --help",
"optimum-cli onnxruntime optimize --help",
]

for command in commands:
subprocess.run(command, shell=True, check=True)

def test_basic_commands(self):
def test_env_commands(self):
subprocess.run("optimum-cli env", shell=True, check=True)

def test_export_commands(self):
with tempfile.TemporaryDirectory() as tempdir:
commands = [
"optimum-cli env",
command = (
f"optimum-cli export onnx --model hf-internal-testing/tiny-random-vision_perceiver_conv --task image-classification {tempdir}",
)
subprocess.run(command, shell=True, check=True)

def test_optimize_commands(self):
with tempfile.TemporaryDirectory() as tempdir:
# First export a tiny encoder, decoder only and encoder-decoder
export_commands = [
f"optimum-cli export onnx --model hf-internal-testing/tiny-random-BertModel {tempdir}/encoder",
f"optimum-cli export onnx --model hf-internal-testing/tiny-random-gpt2 {tempdir}/decoder",
f"optimum-cli export onnx --model hf-internal-testing/tiny-random-t5 {tempdir}/encoder-decoder",
]
optimize_commands = [
f"optimum-cli onnxruntime optimize --onnx_model {tempdir}/encoder -O1",
f"optimum-cli onnxruntime optimize --onnx_model {tempdir}/decoder -O1",
f"optimum-cli onnxruntime optimize --onnx_model {tempdir}/encoder-decoder -O1",

Review comment (Member): Maybe add a test with a custom config here

Reply (Contributor, author): Do you already have an ORTConfig file for testing usage somewhere?

Reply (Contributor, author): @fxmarty @michaelbenayoun I have an issue with testing the ORTConfig parameter. Apparently the ORTConfig.optimization and ORTConfig.quantization parameters are dictionaries. How can I get them back to the usual dataclass objects? The quantize and optimize methods expect objects, not dictionaries.

Reply (Member): I could not find one, so forget about it!

]

for export, optimize in zip(export_commands, optimize_commands):
subprocess.run(export, shell=True, check=True)
subprocess.run(optimize, shell=True, check=True)

def test_quantize_commands(self):
with tempfile.TemporaryDirectory() as tempdir:
# First export a tiny encoder, decoder only and encoder-decoder
export_commands = [
f"optimum-cli export onnx --model hf-internal-testing/tiny-random-BertModel {tempdir}/encoder",
f"optimum-cli export onnx --model hf-internal-testing/tiny-random-gpt2 {tempdir}/decoder",
f"optimum-cli export onnx --model hf-internal-testing/tiny-random-t5 {tempdir}/encoder-decoder",
]
quantize_commands = [
f"optimum-cli onnxruntime quantize --onnx_model {tempdir}/encoder --avx2",
f"optimum-cli onnxruntime quantize --onnx_model {tempdir}/decoder --avx2",
f"optimum-cli onnxruntime quantize --onnx_model {tempdir}/encoder-decoder --avx2",

Review comment (Member): Same comments

]

for command in commands:
subprocess.run(command, shell=True, check=True)
for export, quantize in zip(export_commands, quantize_commands):
subprocess.run(export, shell=True, check=True)
subprocess.run(quantize, shell=True, check=True)