Determine framework automatically before ONNX export #18615
Merged: patrickvonplaten merged 15 commits from rachthree:onnx-export-driver-auto-framework into huggingface:main on Aug 25, 2022.
Commits (15, all by rachthree):
c7d1cb4  Automatic detection for framework to use when exporting to ONNX
f45381a  Log message change
0c73254  Merge branch 'main' of github.com:rachthree/transformers into onnx-ex…
b832988  Incorporating PR comments, adding unit test
d2f78c8  Merge branch 'onnx-export-driver-auto-framework' of github.com:rachth…
ce96dee  Adding tf for pip install for run_tests_onnxruntime CI
67416f2  Restoring past changes to circleci yaml and test_onnx_v2.py, tests mo…
695c72c  Merge branch 'main' of github.com:rachthree/transformers into onnx-ex…
cfcae03  Fixup
8da5990  Adding test to fetcher
63198fd  Updating circleci config to log more
8787399  Changing test class name
6a619ff  Comment typo fix in tests/onnx/test_features.py
6bd7477  Moving torch_str/tf_str to self.framework_pt/tf
d8f3804  Remove -rA flag in circleci config
src/transformers/onnx/features.py:

@@ -1,10 +1,11 @@
 import os
 from functools import partial, reduce
 from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type, Union

 import transformers

 from .. import PretrainedConfig, is_tf_available, is_torch_available
-from ..utils import logging
+from ..utils import TF2_WEIGHTS_NAME, WEIGHTS_NAME, logging
 from .config import OnnxConfig

@@ -557,9 +558,59 @@ def get_model_class_for_feature(feature: str, framework: str = "pt") -> Type:
             )
         return task_to_automodel[task]

+    @staticmethod
+    def determine_framework(model: str, framework: str = None) -> str:
+        """
+        Determines the framework to use for the export.
+
+        The priority is in the following order:
+        1. User input via `framework`.
+        2. If local checkpoint is provided, use the same framework as the checkpoint.
+        3. Available framework in environment, with priority given to PyTorch.
+
+        Args:
+            model (`str`):
+                The name of the model to export.
+            framework (`str`, *optional*, defaults to `None`):
+                The framework to use for the export. See above for priority if none provided.
+
+        Returns:
+            The framework to use for the export.
+        """
+        if framework is not None:
+            return framework
+
+        framework_map = {"pt": "PyTorch", "tf": "TensorFlow"}
+        exporter_map = {"pt": "torch", "tf": "tf2onnx"}
+
+        if os.path.isdir(model):
+            if os.path.isfile(os.path.join(model, WEIGHTS_NAME)):
+                framework = "pt"
+            elif os.path.isfile(os.path.join(model, TF2_WEIGHTS_NAME)):
+                framework = "tf"
+            else:
+                raise FileNotFoundError(
+                    "Cannot determine framework from given checkpoint location."
+                    f" There should be a {WEIGHTS_NAME} for PyTorch"
+                    f" or {TF2_WEIGHTS_NAME} for TensorFlow."
+                )
+            logger.info(f"Local {framework_map[framework]} model found.")
+        else:
+            if is_torch_available():
+                framework = "pt"
+            elif is_tf_available():
+                framework = "tf"
+            else:
+                raise EnvironmentError("Neither PyTorch nor TensorFlow found in environment. Cannot export to ONNX.")
+
+        logger.info(f"Framework not requested. Using {exporter_map[framework]} to export to ONNX.")
+
+        return framework
+
     @staticmethod
     def get_model_from_feature(
-        feature: str, model: str, framework: str = "pt", cache_dir: str = None
+        feature: str, model: str, framework: str = None, cache_dir: str = None
     ) -> Union["PreTrainedModel", "TFPreTrainedModel"]:
         """
         Attempts to retrieve a model from a model's name and the feature to be enabled.

@@ -569,20 +620,24 @@ def get_model_from_feature(
                 The feature required.
             model (`str`):
                 The name of the model to export.
-            framework (`str`, *optional*, defaults to `"pt"`):
-                The framework to use for the export.
+            framework (`str`, *optional*, defaults to `None`):
+                The framework to use for the export. See `FeaturesManager.determine_framework` for the priority should
+                none be provided.

         Returns:
             The instance of the model.
         """
+        framework = FeaturesManager.determine_framework(model, framework)
         model_class = FeaturesManager.get_model_class_for_feature(feature, framework)
         try:
             model = model_class.from_pretrained(model, cache_dir=cache_dir)
         except OSError:
             if framework == "pt":
+                logger.info("Loading TensorFlow model in PyTorch before exporting to ONNX.")
                 model = model_class.from_pretrained(model, from_tf=True, cache_dir=cache_dir)
             else:
+                logger.info("Loading PyTorch model in TensorFlow before exporting to ONNX.")
                 model = model_class.from_pretrained(model, from_pt=True, cache_dir=cache_dir)
         return model

Review comment on the new logger.info lines:
Reviewer: Nice idea to log these steps for the user!
rachthree: Thanks! It helped me figure out the behavior, so hope it's helpful for others!
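The three-step priority that `determine_framework` implements (explicit user choice, then local checkpoint contents, then installed frameworks with PyTorch first) can be sketched as a small standalone function. This is a simplified, hypothetical re-implementation for illustration only: the weights file names mirror `WEIGHTS_NAME`/`TF2_WEIGHTS_NAME`, and framework availability is injected as boolean flags instead of calling `is_torch_available`/`is_tf_available`.

```python
import os
import tempfile

# File names assumed to match transformers' WEIGHTS_NAME / TF2_WEIGHTS_NAME.
PT_WEIGHTS = "pytorch_model.bin"
TF_WEIGHTS = "tf_model.h5"


def sketch_determine_framework(model, framework=None, torch_ok=True, tf_ok=True):
    # 1. Explicit user choice always wins.
    if framework is not None:
        return framework
    # 2. A local checkpoint directory decides by which weights file it contains.
    if os.path.isdir(model):
        if os.path.isfile(os.path.join(model, PT_WEIGHTS)):
            return "pt"
        if os.path.isfile(os.path.join(model, TF_WEIGHTS)):
            return "tf"
        raise FileNotFoundError("No recognizable weights file in checkpoint directory.")
    # 3. Otherwise fall back to whatever is installed, PyTorch first.
    if torch_ok:
        return "pt"
    if tf_ok:
        return "tf"
    raise EnvironmentError("Neither PyTorch nor TensorFlow is available.")


# Usage: a directory containing a PyTorch weights file is detected as "pt".
with tempfile.TemporaryDirectory() as d:
    open(os.path.join(d, PT_WEIGHTS), "w").close()
    assert sketch_determine_framework(d) == "pt"
```

Injecting availability as parameters keeps the sketch testable without mocking; the real implementation queries the environment directly, which is why the PR's tests patch `is_torch_available`/`is_tf_available` instead.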
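The cross-framework fallback in `get_model_from_feature` — try the native load first, and on `OSError` retry with a conversion from the other framework — is a general pattern worth isolating. A minimal sketch with stand-in loader callables (`load_with_fallback`, `load_native`, `load_converted` are hypothetical names for illustration, not transformers APIs):

```python
def load_with_fallback(load_native, load_converted, framework):
    """Try the checkpoint in the detected framework first; if no native
    weights exist (OSError), convert from the other framework's weights."""
    try:
        return load_native()
    except OSError:
        other = "TensorFlow" if framework == "pt" else "PyTorch"
        print(f"Loading {other} model in the requested framework before exporting to ONNX.")
        return load_converted()


# Stand-in loaders: the native load fails, so the converted path is taken.
def native_loader():
    raise OSError("no native checkpoint found")


model = load_with_fallback(native_loader, lambda: "converted-model", "pt")
```

In the real code both branches call `model_class.from_pretrained`, differing only in the `from_tf=True` / `from_pt=True` flag.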
tests/onnx/test_features.py (new file, +111 lines):

from tempfile import TemporaryDirectory
from unittest import TestCase
from unittest.mock import MagicMock, patch

from transformers import AutoModel, TFAutoModel
from transformers.onnx import FeaturesManager
from transformers.testing_utils import SMALL_MODEL_IDENTIFIER, require_tf, require_torch


@require_torch
@require_tf
class DetermineFrameworkTest(TestCase):
    """
    Test `FeaturesManager.determine_framework`
    """

    def setUp(self):
        self.test_model = SMALL_MODEL_IDENTIFIER
        self.framework_pt = "pt"
        self.framework_tf = "tf"

    def _setup_pt_ckpt(self, save_dir):
        model_pt = AutoModel.from_pretrained(self.test_model)
        model_pt.save_pretrained(save_dir)

    def _setup_tf_ckpt(self, save_dir):
        model_tf = TFAutoModel.from_pretrained(self.test_model, from_pt=True)
        model_tf.save_pretrained(save_dir)

    def test_framework_provided(self):
        """
        Ensure that the provided framework is returned.
        """
        mock_framework = "mock_framework"

        # Framework provided - return whatever the user provides
        result = FeaturesManager.determine_framework(self.test_model, mock_framework)
        self.assertEqual(result, mock_framework)

        # Local checkpoint and framework provided - return provided framework
        # PyTorch checkpoint
        with TemporaryDirectory() as local_pt_ckpt:
            self._setup_pt_ckpt(local_pt_ckpt)
            result = FeaturesManager.determine_framework(local_pt_ckpt, mock_framework)
            self.assertEqual(result, mock_framework)

        # TensorFlow checkpoint
        with TemporaryDirectory() as local_tf_ckpt:
            self._setup_tf_ckpt(local_tf_ckpt)
            result = FeaturesManager.determine_framework(local_tf_ckpt, mock_framework)
            self.assertEqual(result, mock_framework)

    def test_checkpoint_provided(self):
        """
        Ensure that the determined framework is the one used for the local checkpoint.

        For the functionality to execute, local checkpoints are provided but framework is not.
        """
        # PyTorch checkpoint
        with TemporaryDirectory() as local_pt_ckpt:
            self._setup_pt_ckpt(local_pt_ckpt)
            result = FeaturesManager.determine_framework(local_pt_ckpt)
            self.assertEqual(result, self.framework_pt)

        # TensorFlow checkpoint
        with TemporaryDirectory() as local_tf_ckpt:
            self._setup_tf_ckpt(local_tf_ckpt)
            result = FeaturesManager.determine_framework(local_tf_ckpt)
            self.assertEqual(result, self.framework_tf)

        # Invalid local checkpoint
        with TemporaryDirectory() as local_invalid_ckpt:
            with self.assertRaises(FileNotFoundError):
                result = FeaturesManager.determine_framework(local_invalid_ckpt)

    def test_from_environment(self):
        """
        Ensure that the determined framework is the one available in the environment.

        For the functionality to execute, framework and local checkpoints are not provided.
        """
        # Framework not provided, hub model is used (no local checkpoint directory)
        # TensorFlow not in environment -> use PyTorch
        mock_tf_available = MagicMock(return_value=False)
        with patch("transformers.onnx.features.is_tf_available", mock_tf_available):
            result = FeaturesManager.determine_framework(self.test_model)
            self.assertEqual(result, self.framework_pt)

        # PyTorch not in environment -> use TensorFlow
        mock_torch_available = MagicMock(return_value=False)
        with patch("transformers.onnx.features.is_torch_available", mock_torch_available):
            result = FeaturesManager.determine_framework(self.test_model)
            self.assertEqual(result, self.framework_tf)

        # Both in environment -> use PyTorch
        mock_tf_available = MagicMock(return_value=True)
        mock_torch_available = MagicMock(return_value=True)
        with patch("transformers.onnx.features.is_tf_available", mock_tf_available), patch(
            "transformers.onnx.features.is_torch_available", mock_torch_available
        ):
            result = FeaturesManager.determine_framework(self.test_model)
            self.assertEqual(result, self.framework_pt)

        # Both not in environment -> raise error
        mock_tf_available = MagicMock(return_value=False)
        mock_torch_available = MagicMock(return_value=False)
        with patch("transformers.onnx.features.is_tf_available", mock_tf_available), patch(
            "transformers.onnx.features.is_torch_available", mock_torch_available
        ):
            with self.assertRaises(EnvironmentError):
                result = FeaturesManager.determine_framework(self.test_model)
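The environment tests above work by patching module-level availability checks with `MagicMock` objects, so no framework actually has to be missing. The same technique applies to any module-level callable; here is a self-contained sketch patching a stdlib function (nothing in it touches transformers):

```python
import os.path
from unittest.mock import MagicMock, patch

# Replace os.path.isdir for the duration of the with-block only.
mock_isdir = MagicMock(return_value=False)
with patch("os.path.isdir", mock_isdir):
    inside = os.path.isdir(os.sep)  # answered by the mock, not the filesystem
after = os.path.isdir(os.sep)       # real function is restored on exit

assert inside is False
assert after is True
assert mock_isdir.call_count == 1
```

The key detail, as in the tests above, is patching the name where it is looked up (`transformers.onnx.features.is_tf_available`, not `transformers.is_tf_available`), since `patch` replaces the attribute on the module that the code under test actually reads.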
Review discussion:

Reviewer: I would love to see the logic in this function unit tested if you're up for it, e.g. under tests/onnx/test_features.py. You could use SMALL_MODEL_IDENTIFIER to save a tiny torch/tf model to a temporary directory.

rachthree: Added a unit test in b832988, but I put it under tests/onnx/test_onnx_v2::OnnxUtilsTestCaseV2. I just noticed you specified test_features.py, but it does not exist yet. I can create it if you'd like, or should I leave it as is?

Reviewer: Thank you!! Yes, please create a new test_features.py file for this test (we usually map transformers/path/to/module.py to tests/path/to/test_module.py).

rachthree: Thanks! Done in 67416f2. In 8da5990 I registered the tests in utils/tests_fetcher.py because of a CI failure saying that the test would not be discovered. Is this the correct way to add them? In 63198fd I added tf to the pip install steps for run_tests_onnxruntime and run_tests_onnxruntime_all in .circleci/config.yml so that TFAutoModel can be used. I also added -rA flags so that the results would be more verbose. The logs for run_tests_onnxruntime show that the new unit tests are run.

Reviewer: Thank you for registering the test - this is indeed the way to include it :)