diff --git a/tests/common.py b/tests/common.py index 80144bb91..2d5f4103c 100644 --- a/tests/common.py +++ b/tests/common.py @@ -24,6 +24,8 @@ "check_onnxruntime_backend", "check_tf_min_version", "check_tf_max_version", + "check_tfjs_min_version", + "check_tfjs_max_version", "skip_tf_versions", "skip_tf_cpu", "check_onnxruntime_min_version", @@ -272,6 +274,25 @@ def requires_custom_ops(message=""): can_import = False return unittest.skipIf(not can_import, reason) +def check_tfjs_max_version(max_accepted_version, message=""): + """ Skip if tfjs_version > max_required_version """ + reason = _append_message("conversion requires tensorflowjs <= {}".format(max_accepted_version), message) + try: + import tensorflowjs + can_import = True + except ModuleNotFoundError: + can_import = False + return unittest.skipIf(can_import and not config.skip_tfjs_tests and tensorflowjs.__version__ > LooseVersion(max_accepted_version), reason) + +def check_tfjs_min_version(min_required_version, message=""): + """ Skip if tjs_version < min_required_version """ + reason = _append_message("conversion requires tensorflowjs >= {}".format(min_required_version), message) + try: + import tensorflowjs + can_import = True + except ModuleNotFoundError: + can_import = False + return unittest.skipIf(can_import and not config.skip_tfjs_tests and tensorflowjs.__version__ < LooseVersion(min_required_version), reason) def check_tf_max_version(max_accepted_version, message=""): """ Skip if tf_version > max_required_version """ diff --git a/tests/test_cond.py b/tests/test_cond.py index 8d74c8dcc..7fa8c5dbc 100644 --- a/tests/test_cond.py +++ b/tests/test_cond.py @@ -118,6 +118,7 @@ def false_fn(): output_names_with_port = ["output:0"] self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port) + @check_tfjs_max_version("3.15", "failed when tfjs version > 3.15") def test_cond_in_while_loop(self): def func(i, inputs): inputs_2 = tf.identity(inputs) diff --git a/tests/test_loops.py b/tests/test_loops.py index b0b8f9213..410bee378 100644 --- a/tests/test_loops.py +++ b/tests/test_loops.py @@ -7,7 +7,8 @@ import tensorflow as tf from backend_test_base import Tf2OnnxBackendTestBase -from common import unittest_main, check_tf_min_version, check_tf_max_version, check_onnxruntime_min_version +from common import unittest_main, check_tf_min_version, check_tf_max_version, \ + check_onnxruntime_min_version, check_tfjs_max_version from tf2onnx.tf_loader import is_tf2 @@ -66,6 +67,7 @@ def func(i): x_val = np.array(3, dtype=np.int32) self.run_test_case(func, {_INPUT: x_val}, [], [_OUTPUT], rtol=1e-06) + @check_tfjs_max_version("3.15", "failed when tfjs version > 3.15") def test_while_loop_with_ta_write(self): def func(i): output_ta = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True) @@ -159,6 +161,7 @@ def b(i, res, res2): output_names_with_port = ["i:0", "x:0", "y:0"] self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06) + @check_tfjs_max_version("3.15", "failed when tfjs version > 3.15") def test_while_loop_with_ta_read_and_write(self): def func(i, inputs): inputs_2 = tf.identity(inputs) @@ -183,6 +186,7 @@ def b(i, out_ta): output_names_with_port = ["i:0", "output_ta:0"] self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06) + @check_tfjs_max_version("3.15", "failed when tfjs version > 3.15") def test_while_loop_with_multi_scan_outputs(self): def func(i, inputs1, inputs2): inputs1_ = tf.identity(inputs1) @@ -217,6 +221,7 @@ def b(i, out_ta, out_ta2): output_names_with_port = ["i:0", "output_ta:0", "output_ta2:0"] self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06) + @check_tfjs_max_version("3.15", "failed when tfjs version > 3.15") @check_onnxruntime_min_version( "0.5.0", "disable this case due to onnxruntime loop issue: https://github.com/microsoft/onnxruntime/issues/1272"