Onnx fix test (#10663)
* Allow passing kwargs to the model's from_pretrained when using pipeline.

* Disable the use of past_key_values for GPT2 when exporting to ONNX.

* style

* Remove comment.

* Appease the documentation gods

* Fix style

Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
mfuntowicz and LysandreJik authored Mar 11, 2021
1 parent a637ae0 commit 3ab6820
Showing 3 changed files with 26 additions and 15 deletions.
10 changes: 7 additions & 3 deletions src/transformers/convert_graph_to_onnx.py
@@ -222,7 +222,9 @@ def build_shape_dict(name: str, tensor, is_input: bool, seq_len: int):
return input_vars, output_names, dynamic_axes, tokens


def load_graph_from_args(pipeline_name: str, framework: str, model: str, tokenizer: Optional[str] = None) -> Pipeline:
def load_graph_from_args(
pipeline_name: str, framework: str, model: str, tokenizer: Optional[str] = None, **models_kwargs
) -> Pipeline:
"""
Convert the set of arguments provided through the CLI to an actual pipeline reference (tokenizer + model
@@ -248,7 +250,7 @@ def load_graph_from_args(pipeline_name: str, framework: str, model: str, tokeniz
print(f"Loading pipeline (model: {model}, tokenizer: {tokenizer})")

# Allocate tokenizer and model
return pipeline(pipeline_name, model=model, tokenizer=tokenizer, framework=framework)
return pipeline(pipeline_name, model=model, tokenizer=tokenizer, framework=framework, model_kwargs=models_kwargs)


def convert_pytorch(nlp: Pipeline, opset: int, output: Path, use_external_format: bool):
@@ -335,6 +337,7 @@ def convert(
tokenizer: Optional[str] = None,
use_external_format: bool = False,
pipeline_name: str = "feature-extraction",
**model_kwargs
):
"""
Convert the pipeline object to the ONNX Intermediate Representation (IR) format
@@ -347,14 +350,15 @@
tokenizer: The name of the model to load for the pipeline, default to the model's name if not provided
use_external_format: Split the model definition from its parameters to allow model bigger than 2GB (PyTorch only)
pipeline_name: The kind of pipeline to instantiate (ner, question-answering, etc.)
model_kwargs: Keyword arguments to be forwarded to the model constructor
Returns:
"""
print(f"ONNX opset version set to: {opset}")

# Load the pipeline
nlp = load_graph_from_args(pipeline_name, framework, model, tokenizer)
nlp = load_graph_from_args(pipeline_name, framework, model, tokenizer, **model_kwargs)

if not output.parent.exists():
print(f"Creating folder {output.parent}")
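The effect of this file's changes: extra keyword arguments given to convert() now travel through load_graph_from_args() and pipeline() into the model's from_pretrained(). A minimal sketch of the intended call, assuming a local, empty output folder; use_cache=False mirrors the GPT-2 test further below and keeps past key values out of the exported graph:

from pathlib import Path

from transformers.convert_graph_to_onnx import convert

# use_cache=False is forwarded as a model kwarg to GPT2's from_pretrained(),
# so the exported ONNX graph carries no past key values.
# The output path is purely illustrative; its parent folder must be empty.
convert("pt", "gpt2", Path("onnx-export/gpt2.onnx"), 12, use_cache=False)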
5 changes: 4 additions & 1 deletion src/transformers/pipelines/__init__.py
@@ -246,6 +246,7 @@ def pipeline(
framework: Optional[str] = None,
revision: Optional[str] = None,
use_fast: bool = True,
model_kwargs: Dict[str, Any] = {},
**kwargs
) -> Pipeline:
"""
@@ -307,6 +308,9 @@ def pipeline(
artifacts on huggingface.co, so ``revision`` can be any identifier allowed by git.
use_fast (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to use a Fast tokenizer if possible (a :class:`~transformers.PreTrainedTokenizerFast`).
model_kwargs:
Additional dictionary of keyword arguments passed along to the model's :obj:`from_pretrained(...,
**model_kwargs)` function.
kwargs:
Additional keyword arguments passed along to the specific pipeline init (see the documentation for the
corresponding pipeline class for possible values).
@@ -383,7 +387,6 @@ def pipeline(
# Instantiate model if needed
if isinstance(model, str):
# Handle transparent TF/PT model conversion
model_kwargs = {}
if framework == "pt" and model.endswith(".h5"):
model_kwargs["from_tf"] = True
logger.warning(
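For callers that use pipeline() directly, the new model_kwargs dictionary is handed verbatim to the model's from_pretrained(). A small sketch of that hook, with use_cache=False again only as an illustration:

from transformers import pipeline

# model_kwargs ends up in from_pretrained(..., **model_kwargs);
# here it asks GPT-2 not to return past key values.
nlp = pipeline("feature-extraction", model="gpt2", model_kwargs={"use_cache": False})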
26 changes: 15 additions & 11 deletions tests/test_onnx.py
@@ -38,19 +38,23 @@ def forward(self, input_ids, some_other_args, token_type_ids, attention_mask):


class OnnxExportTestCase(unittest.TestCase):
MODEL_TO_TEST = ["bert-base-cased", "gpt2", "roberta-base"]
MODEL_TO_TEST = [
# (model_name, model_kwargs)
("bert-base-cased", {}),
("gpt2", {"use_cache": False}), # We don't support exporting GPT2 past keys anymore
]

@require_tf
@slow
def test_export_tensorflow(self):
for model in OnnxExportTestCase.MODEL_TO_TEST:
self._test_export(model, "tf", 12)
for model, model_kwargs in OnnxExportTestCase.MODEL_TO_TEST:
self._test_export(model, "tf", 12, **model_kwargs)

@require_torch
@slow
def test_export_pytorch(self):
for model in OnnxExportTestCase.MODEL_TO_TEST:
self._test_export(model, "pt", 12)
for model, model_kwargs in OnnxExportTestCase.MODEL_TO_TEST:
self._test_export(model, "pt", 12, **model_kwargs)

@require_torch
@slow
@@ -71,8 +75,8 @@ def test_export_custom_bert_model(self):
@require_tf
@slow
def test_quantize_tf(self):
for model in OnnxExportTestCase.MODEL_TO_TEST:
path = self._test_export(model, "tf", 12)
for model, model_kwargs in OnnxExportTestCase.MODEL_TO_TEST:
path = self._test_export(model, "tf", 12, **model_kwargs)
quantized_path = quantize(Path(path))

# Ensure the actual quantized model is not bigger than the original one
@@ -82,15 +86,15 @@ def test_quantize_pytorch(self):
@require_torch
@slow
def test_quantize_pytorch(self):
for model in OnnxExportTestCase.MODEL_TO_TEST:
path = self._test_export(model, "pt", 12)
for model, model_kwargs in OnnxExportTestCase.MODEL_TO_TEST:
path = self._test_export(model, "pt", 12, **model_kwargs)
quantized_path = quantize(path)

# Ensure the actual quantized model is not bigger than the original one
if quantized_path.stat().st_size >= Path(path).stat().st_size:
self.fail("Quantized model is bigger than initial ONNX model")

def _test_export(self, model, framework, opset, tokenizer=None):
def _test_export(self, model, framework, opset, tokenizer=None, **model_kwargs):
try:
# Compute path
with TemporaryDirectory() as tempdir:
@@ -101,7 +105,7 @@ def _test_export(self, model, framework, opset, tokenizer=None):
path.parent.rmdir()

# Export
convert(framework, model, path, opset, tokenizer)
convert(framework, model, path, opset, tokenizer, **model_kwargs)

return path
except Exception as e:
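The quantization tests above pair each export with quantize() and assert that the quantized file is not larger than the original. A rough sketch of that flow outside the test harness, assuming quantize() is importable from the same conversion module and the paths are illustrative:

from pathlib import Path

from transformers.convert_graph_to_onnx import convert, quantize

# Export BERT, then quantize it; quantize() returns the path of the
# quantized ONNX file written alongside the original one.
onnx_path = Path("onnx-export/bert.onnx")
convert("pt", "bert-base-cased", onnx_path, 12)
quantized_path = quantize(onnx_path)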
