v1.6.0: Optimum CLI, Stable Diffusion ONNX export, BetterTransformer & ONNX support for more architectures

@fxmarty released this 23 Dec 15:30

Optimum CLI

The Optimum command line interface is introduced, and is now the official entrypoint for the ONNX export. Example commands:

optimum-cli --help
optimum-cli export onnx --help
optimum-cli export onnx --model bert-base-uncased --task sequence-classification bert_onnx/

Stable Diffusion ONNX export

Optimum now supports the ONNX export of stable diffusion models from the diffusers library:

optimum-cli export onnx --model runwayml/stable-diffusion-v1-5 sd_v15_onnx/
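
One way to consume the exported directory is diffusers' OnnxStableDiffusionPipeline. A minimal sketch, assuming the exported layout matches what that pipeline expects; the execution provider choice is also illustrative:

from diffusers import OnnxStableDiffusionPipeline

# "sd_v15_onnx/" is the output directory of the export command above;
# any ONNX Runtime execution provider can be selected here
pipeline = OnnxStableDiffusionPipeline.from_pretrained("sd_v15_onnx", provider="CPUExecutionProvider")
image = pipeline("A cat sitting on a windowsill").images[0]
image.save("cat.png")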

BetterTransformer support for more architectures

The BetterTransformer integration covers new models in this release: CLIP, RemBERT, mBART, ViLT and FSMT.

The complete list of supported models is available in the documentation.
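
As a reminder of the API, a supported model is converted with BetterTransformer.transform. A minimal sketch, with the checkpoint name chosen purely for illustration:

from transformers import AutoModel
from optimum.bettertransformer import BetterTransformer

# Load one of the newly supported architectures (mBART here)
model = AutoModel.from_pretrained("facebook/mbart-large-50")
# Swap the supported layers for their BetterTransformer equivalents
model = BetterTransformer.transform(model, keep_original_model=False)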

ONNX export for more architectures

The ONNX export now supports Swin, MobileNet-v1, MobileNet-v2.
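
For instance, the new Swin support can be exercised end-to-end through the ONNX Runtime classes. A minimal sketch, with the checkpoint name chosen for illustration:

from optimum.onnxruntime import ORTModelForImageClassification

# Export the PyTorch checkpoint to ONNX on the fly and load it with ONNX Runtime
model = ORTModelForImageClassification.from_pretrained(
    "microsoft/swin-tiny-patch4-window7-224", from_transformers=True
)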

Extended ONNX export for encoder-decoder and decoder models

Encoder-decoder and decoder-only models that normally rely on the generate() method in transformers can now be exported as several ONNX files using the --for-ort argument:

optimum-cli export onnx --model t5-small --task seq2seq-lm-with-past --for-ort t5_small_onnx

yielding:

.
└── t5_small_onnx
    ├── config.json
    ├── decoder_model.onnx
    ├── decoder_with_past_model.onnx
    ├── encoder_model.onnx
    ├── special_tokens_map.json
    ├── spiece.model
    ├── tokenizer_config.json
    └── tokenizer.json

When --for-ort is passed, the exported models are expected to be loadable directly into an ORTModel.
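
For instance, reusing the export above, the resulting directory can be loaded with the seq2seq ONNX Runtime class. A minimal sketch:

from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForSeq2SeqLM

# Load the encoder, decoder and decoder-with-past ONNX files exported above
tokenizer = AutoTokenizer.from_pretrained("t5_small_onnx")
model = ORTModelForSeq2SeqLM.from_pretrained("t5_small_onnx")

inputs = tokenizer("translate English to French: Hello, how are you?", return_tensors="pt")
outputs = model.generate(**inputs)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))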

  • Add ort export in exporters for encoder-decoder models by @mht-sharma in #497
  • Support decoder generated with --for-ort from optimum.exporters.onnx in ORTDecoder by @fxmarty in #554

Support for ONNX models with external data at export, optimization, quantization

The ONNX export from PyTorch normally creates external data files when the exported model is larger than 2 GB. This release introduces better support for exporting and using large models, writing all external data into a single .onnx_data file when necessary.
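
Once exported, the proto and its external data can be handled as usual with the onnx package, provided the .onnx_data file stays next to the .onnx file. A minimal sketch, with illustrative paths:

import onnx

# onnx.load() resolves the external data automatically when the companion
# .onnx_data file sits in the same directory as the .onnx file
model_proto = onnx.load("large_model_onnx/model.onnx")
print(f"Number of initializers: {len(model_proto.graph.initializer)}")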

  • Handling ONNX models with external data by @NouamaneTazi in #586
  • Improve the compatibility dealing with large ONNX proto in ORTOptimizer and ORTQuantizer by @JingyaHuang in #332

ONNX Runtime API improvement

Various improvements to allow for a better user experience in the ONNX Runtime integration:

  • ORTModel, ORTModelDecoder and ORTModelForConditionalGeneration can now load ONNX model files regardless of their names, making it possible to load optimized and quantized models without having to specify a file name argument (see the sketch after this list).

  • ORTModel.from_pretrained() with from_transformers=True now downloads and loads the model in a temporary directory instead of the cache, which was not the right place to store it.

  • ORTQuantizer.save_pretrained() now saves the model configuration and the preprocessor, making the exported directory usable end-to-end.

  • ORTOptimizer.save_pretrained() now saves the preprocessor, making the exported directory usable end-to-end.

  • ONNX Runtime integration API improvement by @michaelbenayoun in #515
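
A minimal sketch of the file-name-free loading; the local directory name is an illustrative assumption, e.g. an ORTQuantizer output containing a model_quantized.onnx:

from optimum.onnxruntime import ORTModelForSequenceClassification

# The directory can contain an optimized or quantized ONNX file with an
# arbitrary name; no file_name argument needs to be passed anymore
model = ORTModelForSequenceClassification.from_pretrained("quantized_model_dir")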

Custom shapes support at ONNX export

The shapes of the example inputs provided for the ONNX export can now be overridden, which is useful when the validity of the exported ONNX model is sensitive to the shapes used during the export.

Read more: optimum-cli export onnx --help

  • Support custom shapes for dummy inputs by @fxmarty in #522
  • Support for custom input shapes in exporters onnx by @fxmarty in #575

Enable use_cache=True for ORTModelForCausalLM

Reusing past key values with ORTModelForCausalLM (e.g. gpt2) is now possible by passing use_cache=True, which avoids recomputing them at each decoding iteration:

from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM
import torch

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = ORTModelForCausalLM.from_pretrained("gpt2", from_transformers=True, use_cache=True)

inputs = tokenizer("My name is Arthur and I live in", return_tensors="pt")

gen_tokens = model.generate(**inputs)
tokenizer.batch_decode(gen_tokens)

  • Enable past_key_values for ORTModelForCausalLM by @echarlaix in #326

IO binding support for ORTModelForCustomTasks

ORTModelForCustomTasks now supports IO Binding when using CUDAExecutionProvider.
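
A minimal sketch; the checkpoint and provider setup are assumptions, and IO Binding is used when the model runs on CUDAExecutionProvider:

from optimum.onnxruntime import ORTModelForCustomTasks

# ONNX model with custom inputs/outputs (checkpoint name chosen for illustration);
# selecting CUDAExecutionProvider enables the IO Binding path
model = ORTModelForCustomTasks.from_pretrained(
    "optimum/sbert-all-MiniLM-L6-with-pooler", provider="CUDAExecutionProvider"
)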

Experimental support to merge ONNX decoder with/without past key values

Along with --for-ort, passing --task causal-lm-with-past, --task seq2seq-lm-with-past or --task speech2seq-lm-with-past during the ONNX export produces two models: one that does not reuse previously computed keys/values, and one that does.

Experimental support is introduced to merge the two models into one. Example:

optimum-cli export onnx --model t5-small --task seq2seq-lm-with-past --for-ort t5_onnx/

import onnx
from optimum.onnx import merge_decoders

decoder = onnx.load("t5_onnx/decoder_model.onnx")
decoder_with_past = onnx.load("t5_onnx/decoder_with_past_model.onnx")

merged_model = merge_decoders(decoder, decoder_with_past)
onnx.save(merged_model, "t5_onnx/decoder_merged_model.onnx")

Major bugs fixed

Other changes, bugfixes and improvements

Full Changelog: v1.5.2...v1.6.0

Significant community contributions

The following contributors have made significant changes to the library over the last release: