Skip to content

Commit

Permalink
[Relay][TF] Support for Atan/Atan2 in Relay Tensorflow frontend conve…
Browse files Browse the repository at this point in the history
…rter. (#5104)

* add Atan/Atan2 op

* fix bug and testing
  • Loading branch information
hypercubestart authored Mar 20, 2020
1 parent 2b66123 commit a94f69f
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
7 changes: 7 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1535,6 +1535,11 @@ def _impl(inputs, attr, params):

return _impl

def _atan2():
def _impl(inputs, attr, params):
divide = _elemwise("divide")(inputs, attr, params)
return get_relay_op("atan")(divide)
return _impl

def _prod():
def _impl(inputs, attr, params):
Expand Down Expand Up @@ -1615,6 +1620,8 @@ def _impl(inputs, attr, params):
'ArgMax' : _argx(_op.argmax, 'argmax'),
'ArgMin' : _argx(_op.argmin, 'argmin'),
'Assert' : _assert(),
'Atan' : AttrCvt('atan'),
'Atan2' : _atan2(),
'AvgPool' : _pooling('avg_pool'),
'AvgPool3D' : _pool3d('avg_pool3d'),
'BatchMatMul' : _batch_matmul(),
Expand Down
22 changes: 22 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2669,6 +2669,26 @@ def test_forward_tan():
tf.tan(in_data, name="tan")
compare_tf_with_tvm([np_data], ['in_data:0'], 'tan:0')

def test_forward_atan():
"""test operator tan """
tf.disable_eager_execution()
np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
tf.reset_default_graph()
in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data")
tf.atan(in_data, name="atan")
compare_tf_with_tvm([np_data], ['in_data:0'], 'atan:0')

def test_forward_atan2():
"""test operator tan """
tf.disable_eager_execution()
np_data_1 = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
np_data_2 = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
tf.reset_default_graph()
in_data_1 = tf.placeholder(tf.float32, (2, 3, 5), name="in_data_1")
in_data_2 = tf.placeholder(tf.float32, (2, 3, 5), name="in_data_2")
tf.atan2(in_data_1, in_data_2, name="atan2")
compare_tf_with_tvm([np_data_1, np_data_2], ['in_data_1:0', 'in_data_2:0'], 'atan2:0')


def test_forward_sin():
"""test operator sin """
Expand Down Expand Up @@ -3116,6 +3136,8 @@ def test_forward_dilation():
test_forward_left_shift()
test_forward_truncatemod()
test_forward_one_hot()
test_forward_atan()
test_forward_atan2()

# Activations
test_forward_sigmoid()
Expand Down

0 comments on commit a94f69f

Please sign in to comment.