Skip to content

Commit

Permalink
[TF FE] Support Complex Tensors (#20860)
Browse files Browse the repository at this point in the history
* [TF FE] Support complex tensors

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Align output type for Real and Imag operations

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Update decoding complex types

* Add support for ComplexAbs, FFT and IFFT operations

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Correct axes based on a number of inner-most dimensions

* Add layer tests

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Update supported ops documentation

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Add a comment for ComplexTypeMark

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

---------

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
  • Loading branch information
rkazants authored Nov 6, 2023
1 parent 1083b3b commit d0eb27b
Show file tree
Hide file tree
Showing 17 changed files with 804 additions and 31 deletions.
32 changes: 16 additions & 16 deletions src/frontends/tensorflow/docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,8 @@ A "supported operation" is one that TensorFlow Frontend can convert to the OpenV
| CollectiveReduceV2 | NO | |
| CollectiveReduceV3 | NO | |
| CombinedNonMaxSuppression | NO | |
| Complex | NO | |
| ComplexAbs | NO | |
| Complex | YES | |
| ComplexAbs | YES | |
| CompositeTensorVariantFromComponents | NO | |
| CompositeTensorVariantToComponents | NO | |
| CompressElement | NO | |
Expand Down Expand Up @@ -425,9 +425,9 @@ A "supported operation" is one that TensorFlow Frontend can convert to the OpenV
| ExtractImagePatches | YES | |
| ExtractJpegShape | NO | |
| ExtractVolumePatches | NO | |
| FFT | NO | |
| FFT2D | NO | |
| FFT3D | NO | |
| FFT | YES | |
| FFT2D | YES | |
| FFT3D | YES | |
| FIFOQueue | YES | |
| FIFOQueueV2 | YES | |
| Fact | NO | |
Expand Down Expand Up @@ -492,12 +492,12 @@ A "supported operation" is one that TensorFlow Frontend can convert to the OpenV
| HashTableV2 | YES | |
| HistogramFixedWidth | NO | |
| HistogramSummary | NO | |
| IFFT | NO | |
| IFFT2D | NO | |
| IFFT3D | NO | |
| IRFFT | NO | |
| IRFFT2D | NO | |
| IRFFT3D | NO | |
| IFFT | YES | |
| IFFT2D | YES | |
| IFFT3D | YES | |
| IRFFT | YES | |
| IRFFT2D | YES | |
| IRFFT3D | YES | |
| Identity | YES | |
| IdentityN | YES | |
| IdentityReader | NO | |
Expand All @@ -507,7 +507,7 @@ A "supported operation" is one that TensorFlow Frontend can convert to the OpenV
| IgammaGradA | NO | |
| Igammac | NO | |
| IgnoreErrorsDataset | NO | |
| Imag | NO | |
| Imag | YES | |
| ImageProjectiveTransformV2 | NO | |
| ImageProjectiveTransformV3 | NO | |
| ImageSummary | NO | |
Expand Down Expand Up @@ -826,9 +826,9 @@ A "supported operation" is one that TensorFlow Frontend can convert to the OpenV
| QueueIsClosedV2 | NO | |
| QueueSize | NO | |
| QueueSizeV2 | NO | |
| RFFT | NO | |
| RFFT2D | NO | |
| RFFT3D | NO | |
| RFFT | YES | |
| RFFT2D | YES | |
| RFFT3D | YES | |
| RGBToHSV | NO | |
| RaggedBincount | NO | |
| RaggedCountSparseOutput | NO | |
Expand Down Expand Up @@ -876,7 +876,7 @@ A "supported operation" is one that TensorFlow Frontend can convert to the OpenV
| ReaderRestoreStateV2 | NO | |
| ReaderSerializeState | NO | |
| ReaderSerializeStateV2 | NO | |
| Real | NO | |
| Real | YES | |
| RealDiv | YES | |
| RebatchDataset | NO | |
| RebatchDatasetV2 | NO | |
Expand Down
10 changes: 7 additions & 3 deletions src/frontends/tensorflow/src/decoder_proto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,14 @@ ov::Any DecoderProto::get_attribute(const std::string& name) const {

case ::tensorflow::AttrValue::ValueCase::kType: {
auto atype = attrs[0].type();
if (atype != ::tensorflow::DT_STRING) {
return get_ov_type(attrs[0].type());
} else {
if (atype == ::tensorflow::DT_STRING) {
return ov::Any("DT_STRING");
} else if (atype == ::tensorflow::DT_COMPLEX64) {
return ov::Any("DT_COMPLEX64");
} else if (atype == ::tensorflow::DT_COMPLEX128) {
return ov::Any("DT_COMPLEX128");
} else {
return get_ov_type(atype);
}
}

Expand Down
16 changes: 16 additions & 0 deletions src/frontends/tensorflow/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"CheckNumerics", CreatorFunction(translate_identity_op)},
{"CheckNumericsV2", CreatorFunction(translate_identity_op)},
{"ClipByValue", CreatorFunction(translate_clip_by_value_op)},
{"Complex", CreatorFunction(translate_complex_op)},
{"ComplexAbs", CreatorFunction(translate_complex_abs_op)},
{"Concat", CreatorFunction(translate_concat_op)},
{"ConcatV2", CreatorFunction(translate_concat_op)},
{"Const", CreatorFunction(translate_const_op)},
Expand All @@ -178,6 +180,9 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"FakeQuantWithMinMaxVars", CreatorFunction(translate_fake_quant_op)},
{"FakeQuantWithMinMaxVarsPerChannel", CreatorFunction(translate_fake_quant_op)},
{"FakeQuantWithMinMaxArgs", CreatorFunction(translate_fake_quant_with_min_max_args)},
{"FFT", CreatorFunction(translate_fft_op)},
{"FFT2D", CreatorFunction(translate_fft_op)},
{"FFT3D", CreatorFunction(translate_fft_op)},
{"FIFOQueue", CreatorFunction(translate_fifo_queue_op)},
{"FIFOQueueV2", CreatorFunction(translate_fifo_queue_op)},
{"Fill", CreatorFunction(translate_fill_op)},
Expand All @@ -196,7 +201,14 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"IdentityN", CreatorFunction(translate_identity_n_op)},
{"Inv", CreatorFunction(translate_inv_op)},
{"If", CreatorFunction(translate_if_op)},
{"IFFT", CreatorFunction(translate_ifft_op)},
{"IFFT2D", CreatorFunction(translate_ifft_op)},
{"IFFT3D", CreatorFunction(translate_ifft_op)},
{"Imag", CreatorFunction(translate_real_imag_op)},
{"input_arg", CreatorFunction(translate_input_arg_op)},
{"IRFFT", CreatorFunction(translate_irfft_op)},
{"IRFFT2D", CreatorFunction(translate_irfft_op)},
{"IRFFT3D", CreatorFunction(translate_irfft_op)},
{"Iterator", CreatorFunction(translate_iterator_op)},
{"IteratorGetNext", CreatorFunction(translate_iterator_get_next_op)},
{"IteratorV2", CreatorFunction(translate_iterator_op)},
Expand Down Expand Up @@ -248,6 +260,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"Rank", CreatorFunction(translate_rank_op)},
{"RandomUniform", CreatorFunction(translate_random_uniform_op)},
{"RandomUniformInt", CreatorFunction(translate_random_uniform_int_op)},
{"Real", CreatorFunction(translate_real_imag_op)},
{"Reciprocal", CreatorFunction(translate_reciprocal_op)},
{"Relu6", CreatorFunction(translate_relu_6_op)},
{"Reshape", CreatorFunction(translate_reshape_op)},
Expand All @@ -257,6 +270,9 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"ResizeBilinear", CreatorFunction(translate_interpolate_op)},
{"ResizeNearestNeighbor", CreatorFunction(translate_interpolate_op)},
{"ResourceGather", CreatorFunction(translate_resource_gather_op)},
{"RFFT", CreatorFunction(translate_rfft_op)},
{"RFFT2D", CreatorFunction(translate_rfft_op)},
{"RFFT3D", CreatorFunction(translate_rfft_op)},
{"Roll", CreatorFunction(translate_roll_op)},
{"Round", CreatorFunction(translate_round_op)},
{"Rsqrt", CreatorFunction(translate_rsqrt_op)},
Expand Down
7 changes: 7 additions & 0 deletions src/frontends/tensorflow_common/include/common_op_table.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ OP_CONVERTER(translate_broadcast_to_op);
OP_CONVERTER(translate_bucketize_op);
OP_CONVERTER(translate_cast_op);
OP_CONVERTER(translate_clip_by_value_op);
OP_CONVERTER(translate_complex_op);
OP_CONVERTER(translate_complex_abs_op);
OP_CONVERTER(translate_concat_op);
OP_CONVERTER(translate_const_op);
OP_CONVERTER(translate_conv_2d_op);
Expand All @@ -66,6 +68,7 @@ OP_CONVERTER(translate_expand_dims_op);
OP_CONVERTER(translate_extract_image_patches_op);
OP_CONVERTER(translate_fake_quant_op);
OP_CONVERTER(translate_fake_quant_with_min_max_args);
OP_CONVERTER(translate_fft_op);
OP_CONVERTER(translate_fill_op);
OP_CONVERTER(translate_floor_div_op);
OP_CONVERTER_NAMED(translate_fused_batch_norm_op);
Expand All @@ -75,11 +78,13 @@ OP_CONVERTER(translate_gather_nd_op);
OP_CONVERTER(translate_gather_tree_op);
OP_CONVERTER(translate_identity_op);
OP_CONVERTER(translate_identity_n_op);
OP_CONVERTER(translate_ifft_op);
OP_CONVERTER(translate_input_arg_op);
OP_CONVERTER(translate_inv_op);
OP_CONVERTER(translate_invert_permutation_op);
OP_CONVERTER(translate_output_arg_op);
OP_CONVERTER(translate_interpolate_op);
OP_CONVERTER(translate_irfft_op);
OP_CONVERTER(translate_is_finite_op);
OP_CONVERTER(translate_is_inf_op);
OP_CONVERTER(translate_is_nan_op);
Expand Down Expand Up @@ -109,13 +114,15 @@ OP_CONVERTER(translate_range_op);
OP_CONVERTER(translate_rank_op);
OP_CONVERTER(translate_random_uniform_op);
OP_CONVERTER(translate_random_uniform_int_op);
OP_CONVERTER(translate_real_imag_op);
OP_CONVERTER(translate_relu_6_op);
OP_CONVERTER(translate_reciprocal_op);
OP_CONVERTER(translate_reshape_op);
OP_CONVERTER(translate_resource_gather_op);
OP_CONVERTER(translate_reverse_op);
OP_CONVERTER(translate_reverse_v2_op);
OP_CONVERTER(translate_reverse_sequence_op);
OP_CONVERTER(translate_rfft_op);
OP_CONVERTER(translate_roll_op);
OP_CONVERTER(translate_round_op);
OP_CONVERTER(translate_rsqrt_op);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/core/type/element_type.hpp"
#include "openvino/op/util/framework_node.hpp"

namespace ov {
namespace frontend {
namespace tensorflow {

// ComplexTypeMark serves to mark places that require complex type propagation
// that means to represent native complex type with simulating floating-point tensor
// that has one extra dimension to concatenate real and imaginary parts of complex tensor.
// For example, a tensor of complex type with shape [N1, N2, ..., Nk] will be transformed
// into a floating-point tensor [N1, N2, ..., Nk, 2]
// where a slice with index [..., 0] represents a real part and
// a slice with index [..., 1] represents a imaginary part.
class ComplexTypeMark : public ov::op::util::FrameworkNode {
public:
OPENVINO_OP("ComplexTypeMark", "util", ov::op::util::FrameworkNode);

ComplexTypeMark(const ov::Output<ov::Node>& input, const ov::element::Type& complex_part_type)
: ov::op::util::FrameworkNode(ov::OutputVector{input}, 1),
m_complex_part_type(complex_part_type) {
validate_and_infer_types();
}

void validate_and_infer_types() override {
set_output_type(0, ov::element::dynamic, PartialShape::dynamic());
}

std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& inputs) const override {
auto complex_type_mark = std::make_shared<ComplexTypeMark>(inputs[0], m_complex_part_type);
complex_type_mark->set_attrs(get_attrs());
return complex_type_mark;
}

ov::element::Type get_complex_part_type() const {
return m_complex_part_type;
}

private:
ov::element::Type m_complex_part_type;
};

} // namespace tensorflow
} // namespace frontend
} // namespace ov
7 changes: 6 additions & 1 deletion src/frontends/tensorflow_common/include/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,10 @@ void fill_explicit_pads_vectors(const NodeContext& node,
ov::CoordinateDiff& pads_begin,
ov::CoordinateDiff& pads_end);

void default_op_checks(const NodeContext& node, size_t min_input_size, const std::vector<std::string>& supported_ops);
void default_op_checks(const NodeContext& node,
size_t min_input_size,
const std::vector<std::string>& supported_ops,
bool supported_complex = false);

ov::Output<Node> get_elements_number_1d(const Output<Node>& output,
ov::element::Type output_type,
Expand Down Expand Up @@ -155,6 +158,8 @@ ov::Output<ov::Node> get_data_slice(const ov::Output<ov::Node>& data,
const int64_t& stop,
const int64_t& step);

ov::Output<ov::Node> compute_broadcast_args(const ov::Output<ov::Node>& shape1, const ov::Output<ov::Node>& shape2);

} // namespace tensorflow
} // namespace frontend
} // namespace ov
Loading

0 comments on commit d0eb27b

Please sign in to comment.