Skip to content

Commit

Permalink
skip tfjs 3.17 tests
Browse files Browse the repository at this point in the history
Signed-off-by: Deyu Huang <deyhuang@microsoft.com>
  • Loading branch information
hwangdeyu committed May 18, 2022
1 parent 772dbe6 commit f8dd91c
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 1 deletion.
21 changes: 21 additions & 0 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
"check_onnxruntime_backend",
"check_tf_min_version",
"check_tf_max_version",
"check_tfjs_min_version",
"check_tfjs_max_version",
"skip_tf_versions",
"skip_tf_cpu",
"check_onnxruntime_min_version",
Expand Down Expand Up @@ -272,6 +274,25 @@ def requires_custom_ops(message=""):
can_import = False
return unittest.skipIf(not can_import, reason)

def check_tfjs_max_version(max_accepted_version, message=""):
""" Skip if tfjs_version > max_required_version """
reason = _append_message("conversion requires tensorflowjs <= {}".format(max_accepted_version), message)
try:
import tensorflowjs
can_import = True
except ModuleNotFoundError:
can_import = False
return unittest.skipIf(can_import and not config.skip_tfjs_tests and tensorflowjs.__version__ > LooseVersion(max_accepted_version), reason)

def check_tfjs_min_version(min_required_version, message=""):
""" Skip if tjs_version < min_required_version """
reason = _append_message("conversion requires tensorflowjs >= {}".format(min_required_version), message)
try:
import tensorflowjs
can_import = True
except ModuleNotFoundError:
can_import = False
return unittest.skipIf(can_import and not config.skip_tfjs_tests and tensorflowjs.__version__ < LooseVersion(min_required_version), reason)

def check_tf_max_version(max_accepted_version, message=""):
""" Skip if tf_version > max_required_version """
Expand Down
1 change: 1 addition & 0 deletions tests/test_cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def false_fn():
output_names_with_port = ["output:0"]
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port)

@check_tfjs_max_version("3.15", "failed when tfjs version > 3.15")
def test_cond_in_while_loop(self):
def func(i, inputs):
inputs_2 = tf.identity(inputs)
Expand Down
7 changes: 6 additions & 1 deletion tests/test_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import tensorflow as tf

from backend_test_base import Tf2OnnxBackendTestBase
from common import unittest_main, check_tf_min_version, check_tf_max_version, check_onnxruntime_min_version
from common import unittest_main, check_tf_min_version, check_tf_max_version, \
check_onnxruntime_min_version, check_tfjs_max_version
from tf2onnx.tf_loader import is_tf2


Expand Down Expand Up @@ -66,6 +67,7 @@ def func(i):
x_val = np.array(3, dtype=np.int32)
self.run_test_case(func, {_INPUT: x_val}, [], [_OUTPUT], rtol=1e-06)

@check_tfjs_max_version("3.15", "failed when tfjs version > 3.15")
def test_while_loop_with_ta_write(self):
def func(i):
output_ta = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
Expand Down Expand Up @@ -159,6 +161,7 @@ def b(i, res, res2):
output_names_with_port = ["i:0", "x:0", "y:0"]
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)

@check_tfjs_max_version("3.15", "failed when tfjs version > 3.15")
def test_while_loop_with_ta_read_and_write(self):
def func(i, inputs):
inputs_2 = tf.identity(inputs)
Expand All @@ -183,6 +186,7 @@ def b(i, out_ta):
output_names_with_port = ["i:0", "output_ta:0"]
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)

@check_tfjs_max_version("3.15", "failed when tfjs version > 3.15")
def test_while_loop_with_multi_scan_outputs(self):
def func(i, inputs1, inputs2):
inputs1_ = tf.identity(inputs1)
Expand Down Expand Up @@ -217,6 +221,7 @@ def b(i, out_ta, out_ta2):
output_names_with_port = ["i:0", "output_ta:0", "output_ta2:0"]
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)

@check_tfjs_max_version("3.15", "failed when tfjs version > 3.15")
@check_onnxruntime_min_version(
"0.5.0",
"disable this case due to onnxruntime loop issue: https://github.com/microsoft/onnxruntime/issues/1272"
Expand Down

0 comments on commit f8dd91c

Please sign in to comment.