Skip to content

Commit

Permalink
support tf rint op
Browse files Browse the repository at this point in the history
Signed-off-by: hwangdeyu <dejack953@outlook.com>
Co-authored-by: fatcat-z <jiz@microsoft.com>
  • Loading branch information
hwangdeyu and fatcat-z committed Jan 28, 2022
1 parent 59cfa79 commit 17a2880
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
8 changes: 8 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4609,6 +4609,14 @@ def func(x):
return tf.identity(x_, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})

@check_opset_min_version(11, "Round")
def test_rint(self):
x_val = np.array([-2.7, -1.5, -0.0, +0.0, 0.3, 0.5, 1.5, 2.5, 3.4, 3.5, float('nan')], dtype=np.float32)
def func(x):
x_ = tf.math.rint(x)
return tf.identity(x_, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})

@check_opset_min_version(11, "Det")
@unittest.skip("unclear how this is called in tf-2, fix later")
def test_determinant(self):
Expand Down
9 changes: 9 additions & 0 deletions tf2onnx/onnx_opset/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,15 @@ def version_11(cls, ctx, node, **kwargs):
pass


@tf_op("Rint", onnx_op="Round")
class Rint:
@classmethod
def version_11(cls, ctx, node, **kwargs):
# Same with tf round, two different people just happened to write the function.
# https://github.com/tensorflow/tensorflow/issues/709
pass


@tf_op("MatrixDeterminant", onnx_op="Det")
class Det:
@classmethod
Expand Down

0 comments on commit 17a2880

Please sign in to comment.