Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TF FE] Support complex tensors for OnesLike operation #23445

Merged
merged 10 commits into from
Mar 21, 2024
15 changes: 13 additions & 2 deletions src/frontends/tensorflow_common/src/op/ones_like.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ namespace op {
OutputVector translate_ones_like_op(const NodeContext& node) {
default_op_checks(node, 1, {"OnesLike"});
auto x = node.get_input(0);
auto complex_type_mark_x = as_type_ptr<ComplexTypeMark>(x.get_node_shared_ptr());
if (complex_type_mark_x) {
x = complex_type_mark_x->input_value(0);
rkazants marked this conversation as resolved.
Show resolved Hide resolved
}

Output<Node> shape_of = make_shared<v3::ShapeOf>(x, element::i32);
auto one_const = create_same_type_const_scalar<int32_t>(x, 1);

Expand All @@ -35,11 +40,17 @@ OutputVector translate_ones_like_op(const NodeContext& node) {
// remove extra dimension by squeezing
auto zero_dim_ind = make_shared<v0::Constant>(element::i32, Shape{1}, 0);
ones_like = make_shared<v0::Squeeze>(ones_like, zero_dim_ind);

set_node_name(node.get_name(), ones_like.get_node_shared_ptr());

if (complex_type_mark_x) {
auto complex_ones_like = make_shared<ComplexTypeMark>(ones_like, complex_type_mark_x->get_complex_part_type());
return {
complex_ones_like
}
}

return {ones_like};
}

} // namespace op
} // namespace tensorflow
} // namespace frontend
Expand Down
44 changes: 44 additions & 0 deletions tests/layer_tests/tensorflow_tests/test_tf_OnesLike.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,47 @@ def test_ones_like(self, params, ie_device, precision, ir_version, temp_dir,
self._test(*self.create_ones_like_net(**params),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_legacy_frontend=use_legacy_frontend)

rkazants marked this conversation as resolved.
Show resolved Hide resolved
class TestComplexOnesLike(CommonTFLayerTest):
def _prepare_input(self, inputs_info):
rng = np.random.default_rng()
assert 'x_real:0' in inputs_info
assert 'x_imag:0' in inputs_info
x_real_shape = inputs_info['x_real:0']
x_imag_shape = inputs_info['x_imag:0']
inputs_data = {}
inputs_data['x_real:0'] = 4 * rng.random(x_real_shape).astype(self.x_type) - 2
inputs_data['x_imag:0'] = 4 * rng.random(x_imag_shape).astype(self.x_type) - 2
return inputs_data

def create_complex_ones_like_net(self, x_shape, x_type):
self.x_type = x_type
tf.compat.v1.reset_default_graph()
# Create the graph and model
with tf.compat.v1.Session() as sess:
x_real = tf.compat.v1.placeholder(tf.dtypes.as_dtype(x_type), x_shape, 'x_real')
x_imag = tf.compat.v1.placeholder(tf.dtypes.as_dtype(x_type), x_shape, 'x_imag')
x_complex = tf.raw_ops.Complex(real=x_real, imag=x_imag)
ones_like = tf.raw_ops.OnesLike(x=x_complex)
real = tf.raw_ops.Real(input=ones_like)
img = tf.raw_ops.Imag(input=ones_like)
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def

return tf_net, None

test_data_basic = [
dict(x_shape=[], x_type=np.float32),
dict(x_shape=[2], x_type=np.int32),
dict(x_shape=[2, 3, 4], x_type=np.float32),
dict(x_shape=[1, 4, 3, 1], x_type=np.int32),
rkazants marked this conversation as resolved.
Show resolved Hide resolved
]

@pytest.mark.parametrize("params", test_data_basic)
@pytest.mark.precommit_tf_fe
@pytest.mark.nightly
def test_complex_ones_like(self, params, ie_device, precision, ir_version, temp_dir,
use_legacy_frontend):
self._test(*self.create_complex_ones_like_net(**params),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_legacy_frontend=use_legacy_frontend)
Loading