Determine framework automatically before ONNX export #18615

Merged
4 changes: 2 additions & 2 deletions .circleci/config.yml
@@ -878,7 +878,7 @@ jobs:
- v0.5-torch-{{ checksum "setup.py" }}
- v0.5-{{ checksum "setup.py" }}
- run: pip install --upgrade pip
-      - run: pip install .[torch,testing,sentencepiece,onnxruntime,vision,rjieba]
+      - run: pip install .[torch,tf,testing,sentencepiece,onnxruntime,vision,rjieba]
- save_cache:
key: v0.5-onnx-{{ checksum "setup.py" }}
paths:
@@ -912,7 +912,7 @@ jobs:
- v0.5-torch-{{ checksum "setup.py" }}
- v0.5-{{ checksum "setup.py" }}
- run: pip install --upgrade pip
-      - run: pip install .[torch,testing,sentencepiece,onnxruntime,vision]
+      - run: pip install .[torch,tf,testing,sentencepiece,onnxruntime,vision]
- save_cache:
key: v0.5-onnx-{{ checksum "setup.py" }}
paths:
10 changes: 9 additions & 1 deletion src/transformers/onnx/__main__.py
@@ -38,7 +38,15 @@ def main():
"--atol", type=float, default=None, help="Absolute difference tolerence when validating the model."
)
parser.add_argument(
"--framework", type=str, choices=["pt", "tf"], default="pt", help="The framework to use for the ONNX export."
"--framework",
type=str,
choices=["pt", "tf"],
default=None,
help=(
"The framework to use for the ONNX export."
" If not provided, will attempt to use the local checkpoint's original framework"
" or what is available in the environment."
),
)
parser.add_argument("output", type=Path, help="Path indicating where to store generated ONNX model.")
parser.add_argument("--cache_dir", type=str, default=None, help="Path indicating where to store cache.")
63 changes: 59 additions & 4 deletions src/transformers/onnx/features.py
@@ -1,10 +1,11 @@
import os
from functools import partial, reduce
from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type, Union

import transformers

from .. import PretrainedConfig, is_tf_available, is_torch_available
-from ..utils import logging
+from ..utils import TF2_WEIGHTS_NAME, WEIGHTS_NAME, logging
from .config import OnnxConfig


@@ -557,9 +558,59 @@ def get_model_class_for_feature(feature: str, framework: str = "pt") -> Type:
)
return task_to_automodel[task]

@staticmethod
def determine_framework(model: str, framework: str = None) -> str:
"""
Determines the framework to use for the export.
[Review thread]

Member:
I would love to see the logic in this function unit tested if you're up for it, e.g. under tests/onnx/test_features.py.

You could use SMALL_MODEL_IDENTIFIER to save a tiny torch / tf model to a temporary directory as follows:

# Ditto for the TF case
model = AutoModel.from_pretrained(SMALL_MODEL_IDENTIFIER)

with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir)
    framework = determine_framework(tmp_dir)
    ...

Contributor Author:
Added a unit test in b832988, but I put it under tests/onnx/test_onnx_v2::OnnxUtilsTestCaseV2. I just noticed you specified test_features.py, but it does not exist yet. I can create it if you'd like, or should I leave it as is?

Member:
Thank you!! Yes, please create a new test_features.py file for this test (we usually map transformers/path/to/module.py to tests/path/to/test_module.py).

Contributor Author:
Thanks! Done in 67416f2.

In 8da5990 I registered the tests in utils/tests_fetcher.py because of a failure I got in CI saying that the test would not be discovered. Is this the correct way to add them?

In 63198fd I added tf to the pip install steps for run_tests_onnxruntime and run_tests_onnxruntime_all in .circleci/config.yml so that TFAutoModel can be used. Also added -rA flags so that the results would be more verbose. In the logs for run_tests_onnxruntime it can be seen that the new unit tests are exercised.

Member:
Thank you for registering the test - this is indeed the way to include it :)

The priority is in the following order:
1. User input via `framework`.
2. If local checkpoint is provided, use the same framework as the checkpoint.
3. Available framework in environment, with priority given to PyTorch

Args:
model (`str`):
The name of the model to export.
framework (`str`, *optional*, defaults to `None`):
The framework to use for the export. See above for priority if none provided.

Returns:
The framework to use for the export.

"""
if framework is not None:
return framework

framework_map = {"pt": "PyTorch", "tf": "TensorFlow"}
exporter_map = {"pt": "torch", "tf": "tf2onnx"}

if os.path.isdir(model):
if os.path.isfile(os.path.join(model, WEIGHTS_NAME)):
framework = "pt"
elif os.path.isfile(os.path.join(model, TF2_WEIGHTS_NAME)):
framework = "tf"
else:
raise FileNotFoundError(
"Cannot determine framework from given checkpoint location."
f" There should be a {WEIGHTS_NAME} for PyTorch"
f" or {TF2_WEIGHTS_NAME} for TensorFlow."
)
logger.info(f"Local {framework_map[framework]} model found.")
else:
if is_torch_available():
framework = "pt"
elif is_tf_available():
framework = "tf"
else:
raise EnvironmentError("Neither PyTorch nor TensorFlow found in environment. Cannot export to ONNX.")

logger.info(f"Framework not requested. Using {exporter_map[framework]} to export to ONNX.")

return framework
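The priority order implemented by `determine_framework` above (user input, then local checkpoint, then environment) can be sketched as a self-contained function. This is a minimal sketch, not the transformers implementation: the environment checks are passed in as booleans instead of calling `is_torch_available` / `is_tf_available`, and the checkpoint filenames are hardcoded to the usual `pytorch_model.bin` / `tf_model.h5` values:

```python
import os

# Usual checkpoint filenames; in transformers these come from transformers.utils
WEIGHTS_NAME = "pytorch_model.bin"  # PyTorch checkpoint
TF2_WEIGHTS_NAME = "tf_model.h5"    # TensorFlow checkpoint


def determine_framework(
    model: str,
    framework: str = None,
    torch_available: bool = True,
    tf_available: bool = True,
) -> str:
    """Priority: explicit framework > local checkpoint type > environment."""
    if framework is not None:
        return framework  # 1. explicit user choice always wins

    if os.path.isdir(model):  # 2. local checkpoint: infer from saved weights
        if os.path.isfile(os.path.join(model, WEIGHTS_NAME)):
            return "pt"
        if os.path.isfile(os.path.join(model, TF2_WEIGHTS_NAME)):
            return "tf"
        raise FileNotFoundError(
            f"Cannot determine framework: expected {WEIGHTS_NAME} or {TF2_WEIGHTS_NAME}"
        )

    if torch_available:  # 3. environment, with priority given to PyTorch
        return "pt"
    if tf_available:
        return "tf"
    raise EnvironmentError("Neither PyTorch nor TensorFlow found in environment.")
```

A hub model name (not a local directory) falls through to step 3, which is why the tests later in this PR mock the availability checks.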

@staticmethod
def get_model_from_feature(
feature: str, model: str, framework: str = "pt", cache_dir: str = None
feature: str, model: str, framework: str = None, cache_dir: str = None
) -> Union["PreTrainedModel", "TFPreTrainedModel"]:
"""
Attempts to retrieve a model from a model's name and the feature to be enabled.
Expand All @@ -569,20 +620,24 @@ def get_model_from_feature(
The feature required.
model (`str`):
The name of the model to export.
framework (`str`, *optional*, defaults to `"pt"`):
The framework to use for the export.
framework (`str`, *optional*, defaults to `None`):
The framework to use for the export. See `FeaturesManager.determine_framework` for the priority should
none be provided.

Returns:
The instance of the model.

"""
framework = FeaturesManager.determine_framework(model, framework)
model_class = FeaturesManager.get_model_class_for_feature(feature, framework)
try:
model = model_class.from_pretrained(model, cache_dir=cache_dir)
except OSError:
if framework == "pt":
logger.info("Loading TensorFlow model in PyTorch before exporting to ONNX.")
[Review thread]

Member:
Nice idea to log these steps for the user!

Contributor Author:
Thanks! It helped me figure out the behavior, so hope it's helpful for others!

model = model_class.from_pretrained(model, from_tf=True, cache_dir=cache_dir)
else:
logger.info("Loading PyTorch model in TensorFlow before exporting to ONNX.")
model = model_class.from_pretrained(model, from_pt=True, cache_dir=cache_dir)
return model

111 changes: 111 additions & 0 deletions tests/onnx/test_features.py
@@ -0,0 +1,111 @@
from tempfile import TemporaryDirectory
from unittest import TestCase
from unittest.mock import MagicMock, patch

from transformers import AutoModel, TFAutoModel
from transformers.onnx import FeaturesManager
from transformers.testing_utils import SMALL_MODEL_IDENTIFIER, require_tf, require_torch


@require_torch
@require_tf
class DetermineFrameworkTest(TestCase):
"""
Test `FeaturesManager.determine_framework`
"""

def setUp(self):
self.test_model = SMALL_MODEL_IDENTIFIER
self.framework_pt = "pt"
self.framework_tf = "tf"

def _setup_pt_ckpt(self, save_dir):
model_pt = AutoModel.from_pretrained(self.test_model)
model_pt.save_pretrained(save_dir)

def _setup_tf_ckpt(self, save_dir):
model_tf = TFAutoModel.from_pretrained(self.test_model, from_pt=True)
model_tf.save_pretrained(save_dir)

def test_framework_provided(self):
"""
Ensure that the provided framework is returned.
"""
mock_framework = "mock_framework"

# Framework provided - return whatever the user provides
result = FeaturesManager.determine_framework(self.test_model, mock_framework)
self.assertEqual(result, mock_framework)

# Local checkpoint and framework provided - return provided framework
# PyTorch checkpoint
with TemporaryDirectory() as local_pt_ckpt:
self._setup_pt_ckpt(local_pt_ckpt)
result = FeaturesManager.determine_framework(local_pt_ckpt, mock_framework)
self.assertEqual(result, mock_framework)

# TensorFlow checkpoint
with TemporaryDirectory() as local_tf_ckpt:
self._setup_tf_ckpt(local_tf_ckpt)
result = FeaturesManager.determine_framework(local_tf_ckpt, mock_framework)
self.assertEqual(result, mock_framework)

def test_checkpoint_provided(self):
"""
Ensure that the determined framework is the one used for the local checkpoint.

For the functionality to execute, local checkpoints are provided but framework is not.
"""
# PyTorch checkpoint
with TemporaryDirectory() as local_pt_ckpt:
self._setup_pt_ckpt(local_pt_ckpt)
result = FeaturesManager.determine_framework(local_pt_ckpt)
self.assertEqual(result, self.framework_pt)

# TensorFlow checkpoint
with TemporaryDirectory() as local_tf_ckpt:
self._setup_tf_ckpt(local_tf_ckpt)
result = FeaturesManager.determine_framework(local_tf_ckpt)
self.assertEqual(result, self.framework_tf)

# Invalid local checkpoint
with TemporaryDirectory() as local_invalid_ckpt:
with self.assertRaises(FileNotFoundError):
result = FeaturesManager.determine_framework(local_invalid_ckpt)

def test_from_environment(self):
"""
Ensure that the determined framework is the one available in the environment.

For the functionality to execute, framework and local checkpoints are not provided.
"""
# Framework not provided, hub model is used (no local checkpoint directory)
# TensorFlow not in environment -> use PyTorch
mock_tf_available = MagicMock(return_value=False)
with patch("transformers.onnx.features.is_tf_available", mock_tf_available):
result = FeaturesManager.determine_framework(self.test_model)
self.assertEqual(result, self.framework_pt)

# PyTorch not in environment -> use TensorFlow
mock_torch_available = MagicMock(return_value=False)
with patch("transformers.onnx.features.is_torch_available", mock_torch_available):
result = FeaturesManager.determine_framework(self.test_model)
self.assertEqual(result, self.framework_tf)

# Both in environment -> use PyTorch
mock_tf_available = MagicMock(return_value=True)
mock_torch_available = MagicMock(return_value=True)
with patch("transformers.onnx.features.is_tf_available", mock_tf_available), patch(
"transformers.onnx.features.is_torch_available", mock_torch_available
):
result = FeaturesManager.determine_framework(self.test_model)
self.assertEqual(result, self.framework_pt)

# Both not in environment -> raise error
mock_tf_available = MagicMock(return_value=False)
mock_torch_available = MagicMock(return_value=False)
with patch("transformers.onnx.features.is_tf_available", mock_tf_available), patch(
"transformers.onnx.features.is_torch_available", mock_torch_available
):
with self.assertRaises(EnvironmentError):
result = FeaturesManager.determine_framework(self.test_model)
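Note that the tests patch `transformers.onnx.features.is_tf_available` rather than `transformers.is_tf_available`: `mock.patch` must target the name where it is *looked up* by the module under test, not where it was originally defined. A self-contained sketch of this pattern, using an invented module (`features_demo` and `pick` are illustration-only names):

```python
import sys
import types
from unittest.mock import MagicMock, patch

# Build a tiny stand-in module so the example is self-contained; in the real
# tests the patch target is "transformers.onnx.features.is_tf_available".
features = types.ModuleType("features_demo")
features.is_tf_available = lambda: True
features.pick = lambda: "tf" if features.is_tf_available() else "pt"
sys.modules["features_demo"] = features  # make the dotted patch target importable

# Patch the name where the code under test looks it up. While the patch is
# active, pick() sees the mocked availability check; afterwards it is restored.
mock_tf_available = MagicMock(return_value=False)
with patch("features_demo.is_tf_available", mock_tf_available):
    result = features.pick()
```

If the patch instead targeted the defining module, `features_demo.pick` would keep resolving the original attribute and the mock would never be consulted.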
2 changes: 1 addition & 1 deletion utils/tests_fetcher.py
@@ -434,7 +434,7 @@ def module_to_test_file(module_fname):
return "tests/utils/test_cli.py"
# Special case for onnx submodules
elif len(splits) >= 2 and splits[-2] == "onnx":
return ["tests/onnx/test_onnx.py", "tests/onnx/test_onnx_v2.py"]
return ["tests/onnx/test_features.py", "tests/onnx/test_onnx.py", "tests/onnx/test_onnx_v2.py"]
# Special case for utils (not the one in src/transformers, the ones at the root of the repo).
elif len(splits) > 0 and splits[0] == "utils":
default_test_file = f"tests/utils/test_utils_{module_name}"