diff --git a/src/frontends/tensorflow_common/src/op/ones_like.cpp b/src/frontends/tensorflow_common/src/op/ones_like.cpp index 7e6a904dcf247a..6003f26ca3e34c 100644 --- a/src/frontends/tensorflow_common/src/op/ones_like.cpp +++ b/src/frontends/tensorflow_common/src/op/ones_like.cpp @@ -3,9 +3,11 @@ // #include "common_op_table.hpp" +#include "helper_ops/complex_type_mark.hpp" #include "openvino/op/broadcast.hpp" #include "openvino/op/concat.hpp" #include "openvino/op/constant.hpp" +#include "openvino/op/gather.hpp" #include "openvino/op/shape_of.hpp" #include "openvino/op/squeeze.hpp" #include "utils.hpp" @@ -19,8 +21,28 @@ namespace tensorflow { namespace op { OutputVector translate_ones_like_op(const NodeContext& node) { - default_op_checks(node, 1, {"OnesLike"}); + default_op_checks(node, 1, {"OnesLike"}, true); auto x = node.get_input(0); + auto complex_type_mark_x = as_type_ptr(x.get_node_shared_ptr()); + if (complex_type_mark_x) { + x = complex_type_mark_x->input_value(0); + auto gather_index_real = make_shared(element::i32, Shape{1}, 0); + auto minus_one = make_shared(element::i32, Shape{1}, -1); + auto x_real = make_shared(x, gather_index_real, minus_one)->output(0); + Output shape_of_real = make_shared(x_real, element::i32); + + auto one_const = create_same_type_const_scalar(x_real, 1); + Output ones_like = make_shared(one_const, shape_of_real); + + auto zero_const = create_same_type_const_scalar(x_real, 0); + Output zeros_like = make_shared(zero_const, shape_of_real); + auto result = make_shared(OutputVector{ones_like, zeros_like}, -1); + set_node_name(node.get_name(), result); + auto ones_like_complex = make_shared(result, complex_type_mark_x->get_complex_part_type()); + + return {ones_like_complex}; + } + Output shape_of = make_shared(x, element::i32); auto one_const = create_same_type_const_scalar(x, 1); @@ -35,11 +57,9 @@ OutputVector translate_ones_like_op(const NodeContext& node) { // remove extra dimension by squeezing auto zero_dim_ind = make_shared(element::i32, Shape{1}, 0); ones_like = make_shared(ones_like, zero_dim_ind); - set_node_name(node.get_name(), ones_like.get_node_shared_ptr()); return {ones_like}; } - } // namespace op } // namespace tensorflow } // namespace frontend diff --git a/tests/layer_tests/tensorflow_tests/test_tf_OnesLike.py b/tests/layer_tests/tensorflow_tests/test_tf_OnesLike.py index 1a5cb7110e8288..0da2822155c8d9 100644 --- a/tests/layer_tests/tensorflow_tests/test_tf_OnesLike.py +++ b/tests/layer_tests/tensorflow_tests/test_tf_OnesLike.py @@ -43,3 +43,48 @@ def test_ones_like(self, params, ie_device, precision, ir_version, temp_dir, self._test(*self.create_ones_like_net(**params), ie_device, precision, ir_version, temp_dir=temp_dir, use_legacy_frontend=use_legacy_frontend) + + +class TestComplexOnesLike(CommonTFLayerTest): + def _prepare_input(self, inputs_info): + rng = np.random.default_rng() + assert 'x_real:0' in inputs_info + assert 'x_imag:0' in inputs_info + x_real_shape = inputs_info['x_real:0'] + x_imag_shape = inputs_info['x_imag:0'] + inputs_data = {} + inputs_data['x_real:0'] = 4 * rng.random(x_real_shape).astype(self.x_type) - 2 + inputs_data['x_imag:0'] = 4 * rng.random(x_imag_shape).astype(self.x_type) - 2 + return inputs_data + + def create_complex_ones_like_net(self, x_shape, x_type): + self.x_type = x_type + tf.compat.v1.reset_default_graph() + # Create the graph and model + with tf.compat.v1.Session() as sess: + x_real = tf.compat.v1.placeholder(tf.dtypes.as_dtype(x_type), x_shape, 'x_real') + x_imag = tf.compat.v1.placeholder(tf.dtypes.as_dtype(x_type), x_shape, 'x_imag') + x_complex = tf.raw_ops.Complex(real=x_real, imag=x_imag) + ones_like = tf.raw_ops.OnesLike(x=x_complex) + real = tf.raw_ops.Real(input=ones_like) + img = tf.raw_ops.Imag(input=ones_like) + tf.compat.v1.global_variables_initializer() + tf_net = sess.graph_def + + return tf_net, None + + test_data_basic = [ + dict(x_shape=[], x_type=np.float32), + dict(x_shape=[2], x_type=np.float32), + dict(x_shape=[2, 3, 4], x_type=np.float32), + dict(x_shape=[1, 4, 3, 1], x_type=np.float32), + ] + + @pytest.mark.parametrize("params", test_data_basic) + @pytest.mark.precommit_tf_fe + @pytest.mark.nightly + def test_complex_ones_like(self, params, ie_device, precision, ir_version, temp_dir, + use_legacy_frontend): + self._test(*self.create_complex_ones_like_net(**params), + ie_device, precision, ir_version, temp_dir=temp_dir, + use_legacy_frontend=use_legacy_frontend)