
Add MPNet ONNX export #691

Merged: 4 commits into huggingface:main from add-mpnet, Jan 13, 2023
Conversation

@jplu (Contributor) commented on Jan 12, 2023

What does this PR do?

This PR adds support for exporting MPNet models to the ONNX format.
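For context, a hedged example of how such an export is typically invoked through the exporter entry point (the model id and output directory are illustrative only, and the exact CLI flags may differ between optimum versions):

python -m optimum.exporters.onnx --model sentence-transformers/all-mpnet-base-v2 mpnet_onnx/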

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

@@ -111,6 +111,10 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]:
        return {"input_ids": dynamic_axis, "attention_mask": dynamic_axis}


class MPNetOnnxConfig(DistilBertOnnxConfig):
    DEFAULT_ONNX_OPSET = 12
Contributor (review comment on the diff):
Just for my knowledge, is opset 11 failing for this model? If you want, you can add a comment as to why opset 12 is required.

@jplu (Contributor, Author) replied on Jan 12, 2023:

Here is the error when exporting with opset 11:

An error occured with the error message: [ONNXRuntimeError] : 10 : INVALID_GRAPH : Load model from here/model.onnx failed:This is an invalid model. Type Error: Type 'tensor(int64)' of input parameter (/encoder/Add_1_output_0) of operator (Min) in node (/encoder/Min) is invalid..

Apparently the Min operator does not support int64 tensors before opset 12 (as far as I understand from this message).
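For illustration, a minimal standalone sketch of the same export path outside of optimum, using plain torch.onnx.export (the checkpoint, input shapes, and file name are arbitrary choices for this sketch):

import torch
from transformers import MPNetModel

# Tiny MPNet checkpoint mentioned later in this thread; any MPNet checkpoint
# should behave the same with respect to the opset requirement.
model = MPNetModel.from_pretrained("hf-internal-testing/tiny-random-MPNetModel").eval()

# Dummy inputs; shapes are arbitrary and token ids are kept small for the tiny vocabulary.
input_ids = torch.randint(3, 10, (1, 8), dtype=torch.long)
attention_mask = torch.ones_like(input_ids)

# With opset_version=11 the exported graph is rejected by ONNX Runtime with the
# INVALID_GRAPH error quoted above (Min over int64 tensors); opset 12 loads fine.
torch.onnx.export(
    model,
    (input_ids, attention_mask),
    "mpnet.onnx",
    input_names=["input_ids", "attention_mask"],
    output_names=["last_hidden_state"],
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "attention_mask": {0: "batch_size", 1: "sequence_length"},
    },
    opset_version=12,
)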

@fxmarty (Contributor) left a review:

Thank you for your contribution, looks good to me as long as the tests pass!

Edit: you may need to run make style

@HuggingFaceDocBuilderDev commented on Jan 12, 2023

The documentation is not available anymore as the PR was closed or merged.

@jplu (Contributor, Author) commented on Jan 13, 2023

Looks like there is an issue with the multiple_choice task:

Traceback (most recent call last):
  File "/opt/hostedtoolcache/Python/3.8.15/x64/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/hostedtoolcache/Python/3.8.15/x64/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/opt/hostedtoolcache/Python/3.8.15/x64/lib/python3.8/site-packages/optimum/exporters/onnx/__main__.py", line 194, in <module>
    main()
  File "/opt/hostedtoolcache/Python/3.8.15/x64/lib/python3.8/site-packages/optimum/exporters/onnx/__main__.py", line 74, in main
    model = TasksManager.get_model_from_task(task, args.model, framework=args.framework, cache_dir=args.cache_dir)
  File "/opt/hostedtoolcache/Python/3.8.15/x64/lib/python3.8/site-packages/optimum/exporters/tasks.py", line 963, in get_model_from_task
    model = model_class.from_pretrained(model_name_or_path, **kwargs)
  File "/opt/hostedtoolcache/Python/3.8.15/x64/lib/python3.8/site-packages/transformers/models/auto/auto_factory.py", line 463, in from_pretrained
    return model_class.from_pretrained(
  File "/opt/hostedtoolcache/Python/3.8.15/x64/lib/python3.8/site-packages/transformers/modeling_utils.py", line 2379, in from_pretrained
    ) = cls._load_pretrained_model(
  File "/opt/hostedtoolcache/Python/3.8.15/x64/lib/python3.8/site-packages/transformers/modeling_utils.py", line 2695, in _load_pretrained_model
    raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
RuntimeError: Error(s) in loading state_dict for MPNetForMultipleChoice:
	size mismatch for classifier.weight: copying a param with shape torch.Size([2, 64]) from checkpoint, the shape in current model is torch.Size([1, 64]).
	size mismatch for classifier.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([1]).
	You may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.

@fxmarty any idea how to fix this? Or should I remove this task from the task manager?
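(For reference, a minimal sketch of the workaround the error message itself suggests, using the checkpoint from the failing test; whether silently reinitializing the head is acceptable for the test is a separate question, and the fix adopted below was to switch checkpoints.)

from transformers import MPNetForMultipleChoice

# ignore_mismatched_sizes drops the incompatible classifier head from the checkpoint
# and reinitializes it with the shape expected by MPNetForMultipleChoice.
model = MPNetForMultipleChoice.from_pretrained(
    "hf-internal-testing/tiny-random-mpnet",
    ignore_mismatched_sizes=True,
)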

@fxmarty (Contributor) commented on Jan 13, 2023

@jplu It's probably just an issue with hf-internal-testing/tiny-random-mpnet. Could you try using https://huggingface.co/hf-internal-testing/tiny-random-MPNetModel ?

You should be able to run the failing test by specifying -k "test_exporters_cli_pytorch_273_mpnet_multiple_choice" in pytest.

@jplu (Contributor, Author) commented on Jan 13, 2023

The test is skipped:

pytest tests/exporters/test_exporters_onnx_cli.py -k "test_exporters_cli_pytorch_273_mpnet_multiple_choice"
=================================================================================================================================================================================== test session starts ====================================================================================================================================================================================
platform linux -- Python 3.10.8, pytest-7.2.0, pluggy-1.0.0
rootdir: /home/jplu/dev/perso/optimum, configfile: setup.cfg
collected 413 items / 412 deselected / 1 selected

tests/exporters/test_exporters_onnx_cli.py s                                                                                                                                                                                                                                                                                                                                         [100%]

============================================================================================================================================================================ 1 skipped, 412 deselected in 1.41s ============================================================================================================================================================================

@fxmarty (Contributor) commented on Jan 13, 2023

Hmm, that's weird. What if you just pass -k "mpnet"?

@jplu (Contributor, Author) commented on Jan 13, 2023

They are all skipped:

=================================================================================================================================================================================== test session starts ====================================================================================================================================================================================
platform linux -- Python 3.10.8, pytest-7.2.0, pluggy-1.0.0
rootdir: /home/jplu/dev/perso/optimum, configfile: setup.cfg
collected 413 items / 406 deselected / 7 selected

tests/exporters/test_exporters_onnx_cli.py sssssss                                                                                                                                                                                                                                                                                                                                   [100%]

============================================================================================================================================================================ 7 skipped, 406 deselected in 2.62s ============================================================================================================================================================================

@fxmarty (Contributor) commented on Jan 13, 2023
fxmarty commented Jan 13, 2023

I can run the mpnet tests locally on your branch with pytest tests/exporters/test_exporters_onnx_cli.py -k "mpnet" --exitfirst -s. Not sure why they are skipped for you.

(edit: the tests pass well with hf-internal-testing/tiny-random-MPNetModel)
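(A general pytest note, not specific to this PR: adding -rs to the command prints the reason for each skipped test, which helps track down why the mpnet tests are skipped in one environment but not another.)

pytest tests/exporters/test_exporters_onnx_cli.py -k "mpnet" -rs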

@fxmarty (Contributor) commented on Jan 13, 2023

It's all good, thanks for the addition!

@fxmarty fxmarty merged commit 2911a91 into huggingface:main Jan 13, 2023
@jplu jplu deleted the add-mpnet branch January 17, 2023 10:31