diff --git a/src/frontends/tensorflow_common/src/op/reciprocal.cpp b/src/frontends/tensorflow_common/src/op/reciprocal.cpp index 91a38bbf798a1b..08e79c27c2c3df 100644 --- a/src/frontends/tensorflow_common/src/op/reciprocal.cpp +++ b/src/frontends/tensorflow_common/src/op/reciprocal.cpp @@ -3,6 +3,12 @@ // #include "common_op_table.hpp" +#include "helper_ops/complex_type_mark.hpp" +#include "openvino/op/add.hpp" +#include "openvino/op/concat.hpp" +#include "openvino/op/divide.hpp" +#include "openvino/op/gather.hpp" +#include "openvino/op/negative.hpp" #include "openvino/op/power.hpp" #include "utils.hpp" @@ -16,8 +22,34 @@ namespace op { OutputVector translate_reciprocal_op(const NodeContext& node) { // computes element-wise 1/x, where x - input - default_op_checks(node, 1, {"Reciprocal"}); + default_op_checks(node, 1, {"Reciprocal"}, true); auto x = node.get_input(0); + auto complex_type_mark_x = as_type_ptr(x.get_node_shared_ptr()); + if (complex_type_mark_x) { + x = complex_type_mark_x->input_value(0); + auto minus_one = make_shared(element::i32, Shape{1}, -1); + auto two = create_same_type_const_scalar(x, 2); + auto gather_index_real = make_shared(element::i32, Shape{1}, 0); + auto gather_index_imag = make_shared(element::i32, Shape{1}, 1); + auto x_real = make_shared(x, gather_index_real, minus_one)->output(0); + auto x_imag = make_shared(x, gather_index_imag, minus_one)->output(0); + + // compute (a^2+b^2) + auto real_squared_norm = make_shared(x_real, two); + auto img_squared_norm = make_shared(x_imag, two); + auto squared_norm = make_shared(real_squared_norm, img_squared_norm); + + // compute 1/(a+bi) = (a-bi)/(a^2+b^2) + auto complex_reciprocal = make_shared( + make_shared(OutputVector{x_real, make_shared(x_imag)}, -1), + squared_norm); + auto complex_result = + make_shared(complex_reciprocal, complex_type_mark_x->get_complex_part_type()); + set_node_name(node.get_name(), complex_reciprocal); + return {complex_result}; + } + + // For real numbers, computes element-wise 1/x, where x - input auto minus_one_const = create_same_type_const_scalar(x, -1); auto reciprocal = make_shared(x, minus_one_const); set_node_name(node.get_name(), reciprocal); diff --git a/tests/layer_tests/tensorflow_tests/test_tf_Reciprocal.py b/tests/layer_tests/tensorflow_tests/test_tf_Reciprocal.py index cbf0223cb68acb..210615cf78eea3 100644 --- a/tests/layer_tests/tensorflow_tests/test_tf_Reciprocal.py +++ b/tests/layer_tests/tensorflow_tests/test_tf_Reciprocal.py @@ -41,3 +41,46 @@ def test_reciprocal_basic(self, params, ie_device, precision, ir_version, temp_d self._test(*self.create_reciprocal_net(**params), ie_device, precision, ir_version, temp_dir=temp_dir, use_legacy_frontend=use_legacy_frontend) + +class TestComplexReciprocal(CommonTFLayerTest): + def _prepare_input(self, inputs_info): + rng = np.random.default_rng() + assert 'param_real_1:0' in inputs_info + assert 'param_imag_1:0' in inputs_info + param_real_shape_1 = inputs_info['param_real_1:0'] + param_imag_shape_1 = inputs_info['param_imag_1:0'] + inputs_data = {} + inputs_data['param_real_1:0'] = 4 * rng.random(param_real_shape_1).astype(np.float32) - 2 + inputs_data['param_imag_1:0'] = 4 * rng.random(param_imag_shape_1).astype(np.float32) - 2 + + return inputs_data + + def create_complex_reciprocal_net(self, x_shape,x_type): + tf.compat.v1.reset_default_graph() + # Create the graph and model + with tf.compat.v1.Session() as sess: + param_real1 = tf.compat.v1.placeholder(np.float32, x_shape, 'param_real_1') + param_imag1 = tf.compat.v1.placeholder(np.float32, x_shape, 'param_imag_1') + complex_x = tf.raw_ops.Complex(real=param_real1, imag=param_imag1) + reciprocal = tf.raw_ops.Reciprocal(x=complex_x) + real = tf.raw_ops.Real(input=reciprocal) + img = tf.raw_ops.Imag(input=reciprocal) + tf.compat.v1.global_variables_initializer() + tf_net = sess.graph_def + + return tf_net, None + + test_data_basic = [ + dict(x_shape=[], x_type=np.float32), + dict(x_shape=[2, 3], x_type=np.float32), + dict(x_shape=[4, 1, 3], x_type=np.float32), + ] + + @pytest.mark.parametrize("params", test_data_basic) + @pytest.mark.precommit_tf_fe + @pytest.mark.nightly + def test_complex_reciprocal(self, params, ie_device, precision, ir_version, temp_dir, + use_legacy_frontend): + self._test(*self.create_complex_reciprocal_net(**params), + ie_device, precision, ir_version, temp_dir=temp_dir, + use_legacy_frontend=use_legacy_frontend) \ No newline at end of file