Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update from keras2onnx to tf2onnx #15162

Merged
merged 1 commit into from
Jan 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@
"jax>=0.2.8",
"jaxlib>=0.1.65",
"jieba",
"keras2onnx",
"nltk",
"numpy>=1.17",
"onnxconverter-common",
Expand Down Expand Up @@ -147,6 +146,7 @@
"starlette",
"tensorflow-cpu>=2.3",
"tensorflow>=2.3",
"tf2onnx",
"timeout-decorator",
"timm",
"tokenizers>=0.10.1",
Expand Down Expand Up @@ -229,8 +229,8 @@ def run(self):
extras["ja"] = deps_list("fugashi", "ipadic", "unidic_lite", "unidic")
extras["sklearn"] = deps_list("scikit-learn")

extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "keras2onnx")
extras["tf-cpu"] = deps_list("tensorflow-cpu", "onnxconverter-common", "keras2onnx")
extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "tf2onnx")
extras["tf-cpu"] = deps_list("tensorflow-cpu", "onnxconverter-common", "tf2onnx")

extras["torch"] = deps_list("torch")

Expand All @@ -243,7 +243,7 @@ def run(self):

extras["tokenizers"] = deps_list("tokenizers")
extras["onnxruntime"] = deps_list("onnxruntime", "onnxruntime-tools")
extras["onnx"] = deps_list("onnxconverter-common", "keras2onnx") + extras["onnxruntime"]
extras["onnx"] = deps_list("onnxconverter-common", "tf2onnx") + extras["onnxruntime"]
extras["modelcreation"] = deps_list("cookiecutter")

extras["sagemaker"] = deps_list("sagemaker")
Expand Down
8 changes: 4 additions & 4 deletions src/transformers/convert_graph_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: Path, use_external_format

def convert_tensorflow(nlp: Pipeline, opset: int, output: Path):
"""
Export a TensorFlow backed pipeline to ONNX Intermediate Representation (IR
Export a TensorFlow backed pipeline to ONNX Intermediate Representation (IR)

Args:
nlp: The pipeline to be exported
Expand All @@ -312,10 +312,10 @@ def convert_tensorflow(nlp: Pipeline, opset: int, output: Path):
try:
import tensorflow as tf

from keras2onnx import __version__ as k2ov
from keras2onnx import convert_keras, save_model
from tf2onnx import __version__ as t2ov
from tf2onnx import convert_keras, save_model

print(f"Using framework TensorFlow: {tf.version.VERSION}, keras2onnx: {k2ov}")
print(f"Using framework TensorFlow: {tf.version.VERSION}, tf2onnx: {t2ov}")

# Build
input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "tf")
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/dependency_versions_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
"jax": "jax>=0.2.8",
"jaxlib": "jaxlib>=0.1.65",
"jieba": "jieba",
"keras2onnx": "keras2onnx",
"nltk": "nltk",
"numpy": "numpy>=1.17",
"onnxconverter-common": "onnxconverter-common",
Expand Down Expand Up @@ -57,6 +56,7 @@
"starlette": "starlette",
"tensorflow-cpu": "tensorflow-cpu>=2.3",
"tensorflow": "tensorflow>=2.3",
"tf2onnx": "tf2onnx",
"timeout-decorator": "timeout-decorator",
"timm": "timm",
"tokenizers": "tokenizers>=0.10.1",
Expand Down
12 changes: 6 additions & 6 deletions src/transformers/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,12 +175,12 @@
_sympy_available = False


_keras2onnx_available = importlib.util.find_spec("keras2onnx") is not None
_tf2onnx_available = importlib.util.find_spec("tf2onnx") is not None
try:
_keras2onnx_version = importlib_metadata.version("keras2onnx")
logger.debug(f"Successfully imported keras2onnx version {_keras2onnx_version}")
_tf2onnx_version = importlib_metadata.version("tf2onnx")
logger.debug(f"Successfully imported tf2onnx version {_tf2onnx_version}")
except importlib_metadata.PackageNotFoundError:
_keras2onnx_available = False
_tf2onnx_available = False

_onnx_available = importlib.util.find_spec("onnxruntime") is not None
try:
Expand Down Expand Up @@ -429,8 +429,8 @@ def is_coloredlogs_available():
return _coloredlogs_available


def is_keras2onnx_available():
return _keras2onnx_available
def is_tf2onnx_available():
return _tf2onnx_available


def is_onnx_available():
Expand Down
8 changes: 4 additions & 4 deletions src/transformers/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
is_faiss_available,
is_flax_available,
is_ftfy_available,
is_keras2onnx_available,
is_librosa_available,
is_onnx_available,
is_pandas_available,
Expand All @@ -49,6 +48,7 @@
is_soundfile_availble,
is_spacy_available,
is_tensorflow_probability_available,
is_tf2onnx_available,
is_tf_available,
is_timm_available,
is_tokenizers_available,
Expand Down Expand Up @@ -246,9 +246,9 @@ def require_rjieba(test_case):
return test_case


def require_keras2onnx(test_case):
if not is_keras2onnx_available():
return unittest.skip("test requires keras2onnx")(test_case)
def require_tf2onnx(test_case):
if not is_tf2onnx_available():
return unittest.skip("test requires tf2onnx")(test_case)
else:
return test_case

Expand Down
10 changes: 5 additions & 5 deletions tests/test_modeling_tf_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@
_tf_gpu_memory_limit,
is_pt_tf_cross_test,
is_staging_test,
require_keras2onnx,
require_tf,
require_tf2onnx,
slow,
)
from transformers.utils import logging
Expand Down Expand Up @@ -254,24 +254,24 @@ def test_onnx_compliancy(self):

self.assertEqual(len(incompatible_ops), 0, incompatible_ops)

@require_keras2onnx
@require_tf2onnx
@slow
def test_onnx_runtime_optimize(self):
if not self.test_onnx:
return

import keras2onnx
import onnxruntime
import tf2onnx

config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

for model_class in self.all_model_classes:
model = model_class(config)
model(model.dummy_inputs)

onnx_model = keras2onnx.convert_keras(model, model.name, target_opset=self.onnx_min_opset)
onnx_model_proto, _ = tf2onnx.convert.from_keras(model, opset=self.onnx_min_opset)

onnxruntime.InferenceSession(onnx_model.SerializeToString())
onnxruntime.InferenceSession(onnx_model_proto.SerializeToString())

def test_keras_save_load(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
Expand Down