From d5df53703c4aee2bbbdc3bd4c1c86ad7d5d65ae2 Mon Sep 17 00:00:00 2001 From: achetver Date: Wed, 29 Sep 2021 17:44:49 +0300 Subject: [PATCH 01/29] Add GatherND_8 operation --- model-optimizer/extensions/ops/gathernd.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/model-optimizer/extensions/ops/gathernd.py b/model-optimizer/extensions/ops/gathernd.py index 1175698fa96b89..2cf6f7e5420e4e 100644 --- a/model-optimizer/extensions/ops/gathernd.py +++ b/model-optimizer/extensions/ops/gathernd.py @@ -16,7 +16,7 @@ def __init__(self, graph: Graph, attrs: dict): mandatory_props = { 'type': self.op, 'op': self.op, - 'version': 'opset5', + 'version': 'opset8', 'infer': self.infer, 'in_ports_count': 2, 'out_ports_count': 1, @@ -60,7 +60,9 @@ def infer(node: Node): # compute output shape if batch_dims > 0: if is_fully_defined(data_shape[:batch_dims]): - batch = [np.prod(data_shape[:batch_dims]).tolist()] + batch = data_shape[:batch_dims].tolist() + if node['version'] == 'opset5': # Support old version of gather + batch = [np.prod(data_shape[:batch_dims]).tolist()] else: batch = [dynamic_dimension_value] else: From e7e9f1968bf484716305ccf3fa90c45b0c1a2606 Mon Sep 17 00:00:00 2001 From: achetver Date: Fri, 1 Oct 2021 10:25:14 +0300 Subject: [PATCH 02/29] Update shape infer function and tests --- model-optimizer/extensions/ops/gathernd.py | 48 +++++++++---------- .../extensions/ops/gathernd_test.py | 14 +++--- 2 files changed, 30 insertions(+), 32 deletions(-) diff --git a/model-optimizer/extensions/ops/gathernd.py b/model-optimizer/extensions/ops/gathernd.py index 2cf6f7e5420e4e..126960ae059261 100644 --- a/model-optimizer/extensions/ops/gathernd.py +++ b/model-optimizer/extensions/ops/gathernd.py @@ -59,8 +59,8 @@ def infer(node: Node): # compute output shape if batch_dims > 0: - if is_fully_defined(data_shape[:batch_dims]): - batch = data_shape[:batch_dims].tolist() + if is_fully_defined(indices_shape[:batch_dims]): + batch = indices_shape[:batch_dims].tolist() if node['version'] == 'opset5': # Support old version of gather batch = [np.prod(data_shape[:batch_dims]).tolist()] else: @@ -68,31 +68,29 @@ def infer(node: Node): else: batch = [] slice_shape = list(data_shape[(batch_dims + indices_shape[-1]):]) - output_shape = batch + list(indices_shape[batch_dims:-1]) + slice_shape + + batch_dims_size = 1 + + for i in range(batch_dims): + batch_dims_size *= indices_shape[i] + + if indices_shape[-1] == len(data_shape) - batch_dims: + output_shape = batch + list(indices_shape)[batch_dims:-1] + else: + output_shape = batch + list(indices_shape)[batch_dims:-1] + slice_shape node.out_port(0).data.set_shape(output_shape) # compute output value if all input values are defined if is_fully_defined(indices_value) and is_fully_defined(data_value): - output_value = np.zeros(output_shape, dtype=data_value.dtype) - if batch_dims == 0: - output_indices_range = int64_array(indices_shape[:-1]) - for output_index in np.ndindex(tuple(output_indices_range)): - indices_tuple = indices_value[output_index] - output_value[output_index] = data_value[tuple(indices_tuple.T)] - else: - batch_dims_range = int64_array(indices_shape[:batch_dims]) - for batch_indices in np.ndindex(tuple(batch_dims_range)): - # compute batch index in output tensor - batch_ind = 0 - num_elements = 1 - for ind in reversed(range(len(batch_dims_range))): - batch_ind += batch_indices[ind] * num_elements - num_elements *= batch_dims_range[ind] - output_indices_range = int64_array(indices_shape[batch_dims:-1]) - for output_index in np.ndindex(tuple(output_indices_range)): - tmp_ind = batch_indices + output_index - indices_tuple = tuple(indices_value[tmp_ind].T) - full_input_ind = batch_indices + indices_tuple - full_output_ind = tuple(np.array([batch_ind]).T) + output_index - output_value[full_output_ind] = data_value[full_input_ind] + output_data = [] + + reshaped_indices = indices_value.reshape(batch_dims_size, -1, indices_shape[-1]) + + reshaped_data = data_value.reshape((batch_dims_size,) + tuple((data_shape[batch_dims:]))) + + for batch_dim in range(reshaped_indices.shape[0]): + for outer_dim in range(reshaped_indices.shape[1]): + gather_index = tuple(reshaped_indices[batch_dim][outer_dim]) + output_data.append(reshaped_data[(batch_dim,) + gather_index]) + output_value = np.asarray(output_data, dtype=data_value.dtype).reshape(output_shape) node.out_port(0).data.set_value(output_value) diff --git a/model-optimizer/unit_tests/extensions/ops/gathernd_test.py b/model-optimizer/unit_tests/extensions/ops/gathernd_test.py index 4ad84bf527150c..f43386247c55be 100644 --- a/model-optimizer/unit_tests/extensions/ops/gathernd_test.py +++ b/model-optimizer/unit_tests/extensions/ops/gathernd_test.py @@ -14,7 +14,7 @@ 'data_data': {'shape': None, 'value': None, 'kind': 'data'}, 'indices': {'kind': 'op'}, 'indices_data': {'shape': None, 'value': None, 'kind': 'data'}, - 'gathernd_node': {'op': 'GatherNDUpdate', 'kind': 'op', 'batch_dims': 0}, + 'gathernd_node': {'op': 'GatherNDUpdate', 'kind': 'op', 'batch_dims': 0, 'version': 'opset8'}, 'output': {'shape': None, 'value': None, 'kind': 'data'}} # graph 1 @@ -82,7 +82,7 @@ [[[0]], [[2]], [[2]]]])}} -output7 = int64_array([[2], [5], [11], [13], [19], [23]]) +output7 = int64_array([[[2], [5], [11]], [[13], [19], [23]]]) # test data for constant folding: gather elements, batch_dims = 2 inputs8 = {'data_data': {'shape': int64_array([2, 3, 4, 2]), @@ -99,12 +99,12 @@ [[[2, 0], [1, 1], [3, 1]], [[1, 1], [2, 0], [2, 0]], [[0, 0], [3, 1], [3, 1]]]])}} -output8 = int64_array([[3, 8, 6], +output8 = int64_array([[[3, 8, 6], [10, 12, 13], - [23, 24, 22], - [29, 28, 32], + [23, 24, 22]], + [[29, 28, 32], [36, 37, 37], - [41, 48, 48]]) + [41, 48, 48]]]) # test data for partial infer: gather slices and batch_dims=2 inputs9 = {'data_data': {'shape': shape_array([dynamic_dimension_value, 40, 4, 9]), 'value': None}, @@ -162,7 +162,7 @@ def test_partial_infer_gather_slice_batch_dims2(self): GatherND.infer(gathernd_node) # prepare reference results - ref_output_shape = int64_array([400, 3, 5, 9]) + ref_output_shape = int64_array([10, 40, 3, 5, 9]) # get the result res_output_shape = graph.node['output']['shape'] From c41f5adaa9e51438121b7d169738f10cf03ae27e Mon Sep 17 00:00:00 2001 From: achetver Date: Wed, 6 Oct 2021 13:04:09 +0300 Subject: [PATCH 03/29] Initial commit for nGraph GatherND_8 operation --- ngraph/core/include/ngraph/op/gather_nd.hpp | 5 +- ngraph/core/include/openvino/op/gather_nd.hpp | 34 +++++- .../include/openvino/opsets/opset8_tbl.hpp | 2 +- ngraph/core/src/op/gather_nd.cpp | 112 ++++++++++++++++++ 4 files changed, 149 insertions(+), 4 deletions(-) diff --git a/ngraph/core/include/ngraph/op/gather_nd.hpp b/ngraph/core/include/ngraph/op/gather_nd.hpp index 9689be8b854b0b..d8d999cac6bc21 100644 --- a/ngraph/core/include/ngraph/op/gather_nd.hpp +++ b/ngraph/core/include/ngraph/op/gather_nd.hpp @@ -4,7 +4,7 @@ #pragma once -#include "ngraph/op/op.hpp" +#include "ngraph/op/util/gather_nd_base.hpp" #include "openvino/op/gather_nd.hpp" namespace ngraph { @@ -12,5 +12,8 @@ namespace op { namespace v5 { using ov::op::v5::GatherND; } // namespace v5 +namespace v8 { +using ov::op::v8::GatherND; +} // namespace v8 } // namespace op } // namespace ngraph diff --git a/ngraph/core/include/openvino/op/gather_nd.hpp b/ngraph/core/include/openvino/op/gather_nd.hpp index 36f23d8970c930..f1c80efc73266f 100644 --- a/ngraph/core/include/openvino/op/gather_nd.hpp +++ b/ngraph/core/include/openvino/op/gather_nd.hpp @@ -11,9 +11,9 @@ namespace op { namespace v5 { /// \brief GatherND operation /// -class OPENVINO_API GatherND : public Op { +class OPENVINO_API GatherND : public Op::util::GatherNDBase { public: - OPENVINO_OP("GatherND", "opset5", op::Op, 5); + OPENVINO_OP("GatherND", "opset5", op::util::GatherBase, 5); BWDCMP_RTTI_DECLARATION; GatherND() = default; @@ -37,5 +37,35 @@ class OPENVINO_API GatherND : public Op { size_t m_batch_dims; }; } // namespace v5 + +namespace v5 { +/// \brief GatherND operation +/// + class OPENVINO_API GatherND : public Op::util::GatherNDBase { + public: + OPENVINO_OP("GatherND", "opset8", op::util::GatherBase, 8); + BWDCMP_RTTI_DECLARATION; + GatherND() = default; + + /// \brief Constructs a GatherND operation. + /// + /// \param data Node producing data that are gathered + /// \param indices Node producing indices by which the operation gathers elements + /// or slices from data + /// \param batch_dims Specifies a number of batch dimensions + GatherND(const Output& data, const Output& indices, const size_t batch_dims = 0); + + void validate_and_infer_types() override; + bool visit_attributes(AttributeVisitor& visitor) override; + std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; + + size_t get_batch_dims() const { + return m_batch_dims; + } + + private: + size_t m_batch_dims; +}; +} // namespace v8 } // namespace op } // namespace ov diff --git a/ngraph/core/include/openvino/opsets/opset8_tbl.hpp b/ngraph/core/include/openvino/opsets/opset8_tbl.hpp index 0c686bdab09213..e3e9c0a5f8383e 100644 --- a/ngraph/core/include/openvino/opsets/opset8_tbl.hpp +++ b/ngraph/core/include/openvino/opsets/opset8_tbl.hpp @@ -144,7 +144,6 @@ _OPENVINO_OP_REG(SoftPlus, ov::op::v4) _OPENVINO_OP_REG(Swish, ov::op::v4) // New operations added in opset5 -_OPENVINO_OP_REG(GatherND, ov::op::v5) _OPENVINO_OP_REG(GRUSequence, ov::op::v5) _OPENVINO_OP_REG(HSigmoid, ov::op::v5) _OPENVINO_OP_REG(LogSoftmax, ov::op::v5) @@ -175,6 +174,7 @@ _OPENVINO_OP_REG(Roll, ov::op::v7) // New operations added in opset8 _OPENVINO_OP_REG(Gather, ov::op::v8) +_OPENVINO_OP_REG(GatherND, ov::op::v8) _OPENVINO_OP_REG(AdaptiveAvgPool, ov::op::v8) _OPENVINO_OP_REG(AdaptiveMaxPool, ov::op::v8) _OPENVINO_OP_REG(DeformableConvolution, ov::op::v8) diff --git a/ngraph/core/src/op/gather_nd.cpp b/ngraph/core/src/op/gather_nd.cpp index d4c54e5066e3d9..d084279e2ce627 100644 --- a/ngraph/core/src/op/gather_nd.cpp +++ b/ngraph/core/src/op/gather_nd.cpp @@ -122,3 +122,115 @@ shared_ptr op::v5::GatherND::clone_with_new_inputs(const OutputVector& new check_new_args_count(this, new_args); return make_shared(new_args.at(0), new_args.at(1), m_batch_dims); } + +// ------------------------------ V8 ------------------------------ +BWDCMP_RTTI_DEFINITION(op::v8::GatherND); + +op::v8::GatherND::GatherND(const Output& data, const Output& indices, const size_t batch_dims) + : Op({data, indices}), + m_batch_dims(batch_dims) { + constructor_validate_and_infer_types(); +} + +void op::v8::GatherND::validate_and_infer_types() { + NGRAPH_OP_SCOPE(v8_GatherND_validate_and_infer_types); + // check types of input tensors + const auto& data_type = get_input_element_type(0); + const auto& indices_type = get_input_element_type(1); + + NODE_VALIDATION_CHECK(this, + indices_type.is_integral_number(), + "The indices type is expected to be an integer type. Got: ", + indices_type); + + // check ranks of input tensors + const auto& data_pshape = get_input_partial_shape(0); + const auto& indices_pshape = get_input_partial_shape(1); + + if (data_pshape.rank().is_static()) { + NODE_VALIDATION_CHECK(this, data_pshape.rank().get_length() > 0, "Data rank must be at least 1."); + + NODE_VALIDATION_CHECK(this, + data_pshape.rank().get_length() > static_cast(m_batch_dims), + "Number of batch dimensions must not exceed a rank of data."); + } + + if (indices_pshape.rank().is_static()) { + NODE_VALIDATION_CHECK(this, indices_pshape.rank().get_length() > 0, "Indices rank must be at least 1."); + + NODE_VALIDATION_CHECK(this, + indices_pshape.rank().get_length() > static_cast(m_batch_dims), + "Number of batch dimensions must not exceed a rank of indices."); + } + + if (data_pshape.rank().is_static() && indices_pshape.rank().is_static()) { + // check that batch dimensions of data and indices are the same + for (size_t batch_dim = 0; batch_dim < m_batch_dims; batch_dim++) { + if (data_pshape[batch_dim].is_static() && indices_pshape[batch_dim].is_static()) { + NODE_VALIDATION_CHECK(this, + data_pshape[batch_dim].get_length() == indices_pshape[batch_dim].get_length(), + "Batch dimensions of data and indices must be the same."); + } + } + + if (indices_pshape[indices_pshape.rank().get_length() - 1].is_static()) { + NODE_VALIDATION_CHECK( + this, + static_cast(indices_pshape[indices_pshape.rank().get_length() - 1].get_length() + + m_batch_dims) <= data_pshape.rank().get_length(), + "Length of a tuple with indices must not exceed a rank of data tensor " + "excluding " + "batch dimensions."); + } + } + + // set output shape + set_output_size(1); + if (data_pshape.rank().is_static() && indices_pshape.rank().is_static() && + indices_pshape[indices_pshape.rank().get_length() - 1].is_static()) { + auto indices_tuple_length = indices_pshape[indices_pshape.rank().get_length() - 1].get_length(); + int64_t slice_length = data_pshape.rank().get_length() - indices_tuple_length - m_batch_dims; + int64_t output_indices_length = indices_pshape.rank().get_length() - m_batch_dims - 1; + auto output_rank = output_indices_length + slice_length; + size_t delta_output_rank = 0; + if (m_batch_dims > 0) { + delta_output_rank = 1; + } + std::vector output_shape(output_rank + delta_output_rank); + if (m_batch_dims > 0) { + output_shape[0] = 1; + for (size_t dim = 0; dim < m_batch_dims; dim++) { + if (data_pshape[dim].is_static()) { + output_shape[0] *= data_pshape[dim].get_length(); + } else if (indices_pshape[dim].is_static()) { + output_shape[0] *= indices_pshape[dim].get_length(); + } else { + output_shape[0] = Dimension::dynamic(); + break; + } + } + } + for (int64_t dim = 0; dim < output_indices_length; dim++) { + output_shape[dim + delta_output_rank] = indices_pshape[dim + m_batch_dims]; + } + for (int64_t dim = 0; dim < slice_length; dim++) { + output_shape[output_indices_length + dim + delta_output_rank] = + data_pshape[m_batch_dims + indices_tuple_length + dim]; + } + set_output_type(0, data_type, ov::PartialShape(output_shape)); + } else { + set_output_type(0, data_type, ov::PartialShape::dynamic()); + } +} + +bool op::v8::GatherND::visit_attributes(AttributeVisitor& visitor) { + NGRAPH_OP_SCOPE(v8_GatherND_visit_attributes); + visitor.on_attribute("batch_dims", m_batch_dims); + return true; +} + +shared_ptr op::v8::GatherND::clone_with_new_inputs(const OutputVector& new_args) const { + NGRAPH_OP_SCOPE(v8_GatherND_clone_with_new_inputs); + check_new_args_count(this, new_args); + return make_shared(new_args.at(0), new_args.at(1), m_batch_dims); +} \ No newline at end of file From c50914fba6fc2efaf7b13f0083f73e5cd1f00ea7 Mon Sep 17 00:00:00 2001 From: achetver Date: Fri, 8 Oct 2021 13:08:31 +0300 Subject: [PATCH 04/29] Add GatherNDBase class implementation --- .../openvino/op/util/gather_nd_base.hpp | 41 ++++++ ngraph/core/src/op/util/gather_nd_base.cpp | 118 ++++++++++++++++++ 2 files changed, 159 insertions(+) create mode 100644 ngraph/core/include/openvino/op/util/gather_nd_base.hpp create mode 100644 ngraph/core/src/op/util/gather_nd_base.cpp diff --git a/ngraph/core/include/openvino/op/util/gather_nd_base.hpp b/ngraph/core/include/openvino/op/util/gather_nd_base.hpp new file mode 100644 index 00000000000000..bea6e526131c63 --- /dev/null +++ b/ngraph/core/include/openvino/op/util/gather_nd_base.hpp @@ -0,0 +1,41 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/op/op.hpp" + +namespace ov { +namespace op { +namespace util { +/// \brief GatherNDBase basic class for GatherND v5 and v8 +class OPENVINO_API GatherNDBase : public Op { +public: + OPENVINO_OP("GatherNDBase", "util"); + BWDCMP_RTTI_DECLARATION; + GatherNDBase() = default; + + /// \brief Constructs a GatherND operation. + /// + /// \param data Node producing data that are gathered + /// \param indices Node producing indices by which the operation gathers elements + /// or slices from data + /// \param batch_dims Specifies a number of batch dimensions + GatherNDBase(const Output& data, const Output& indices, const size_t batch_dims = 0); + + void validate_and_infer_types() override; + bool visit_attributes(AttributeVisitor& visitor) override; + std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; + + size_t get_batch_dims() const { + return m_batch_dims; + } + +private: + size_t m_batch_dims; +}; +}; +} // namespace util +} // namespace op +} // namespace ov diff --git a/ngraph/core/src/op/util/gather_nd_base.cpp b/ngraph/core/src/op/util/gather_nd_base.cpp new file mode 100644 index 00000000000000..afab706f59e9c0 --- /dev/null +++ b/ngraph/core/src/op/util/gather_nd_base.cpp @@ -0,0 +1,118 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "ngraph/op/util/gather_base.hpp" + +#include + +#include "itt.hpp" +#include "ngraph/op/concat.hpp" +#include "ngraph/op/constant.hpp" +#include "ngraph/op/squeeze.hpp" +#include "ngraph/runtime/host_tensor.hpp" +#include "ngraph/runtime/reference/gather_nd.hpp" +#include "ngraph/shape.hpp" + +using namespace std; + + +BWDCMP_RTTI_DEFINITION(ov::op::util::GatherNDBase); + +ov::op::util::GatherNDBase::GatherNDBase(const Output& data, const Output& indices, const size_t batch_dims) + : Op({data, indices}), + m_batch_dims(batch_dims) { + constructor_validate_and_infer_types(); +} + +void ov::op::util::GatherNDBase::validate_and_infer_types() { + NGRAPH_OP_SCOPE(util_GatherNDBase_validate_and_infer_types); + // check types of input tensors + const auto& data_type = get_input_element_type(0); + const auto& indices_type = get_input_element_type(1); + + NODE_VALIDATION_CHECK(this, + indices_type.is_integral_number(), + "The indices type is expected to be an integer type. Got: ", + indices_type); + + // check ranks of input tensors + const auto& data_pshape = get_input_partial_shape(0); + const auto& indices_pshape = get_input_partial_shape(1); + + if (data_pshape.rank().is_static()) { + NODE_VALIDATION_CHECK(this, data_pshape.rank().get_length() > 0, "Data rank must be at least 1."); + + NODE_VALIDATION_CHECK(this, + data_pshape.rank().get_length() > static_cast(m_batch_dims), + "Number of batch dimensions must not exceed a rank of data."); + } + + if (indices_pshape.rank().is_static()) { + NODE_VALIDATION_CHECK(this, indices_pshape.rank().get_length() > 0, "Indices rank must be at least 1."); + + NODE_VALIDATION_CHECK(this, + indices_pshape.rank().get_length() > static_cast(m_batch_dims), + "Number of batch dimensions must not exceed a rank of indices."); + } + + if (data_pshape.rank().is_static() && indices_pshape.rank().is_static()) { + // check that batch dimensions of data and indices are the same + for (size_t batch_dim = 0; batch_dim < m_batch_dims; batch_dim++) { + if (data_pshape[batch_dim].is_static() && indices_pshape[batch_dim].is_static()) { + NODE_VALIDATION_CHECK(this, + data_pshape[batch_dim].get_length() == indices_pshape[batch_dim].get_length(), + "Batch dimensions of data and indices must be the same."); + } + } + + if (indices_pshape[indices_pshape.rank().get_length() - 1].is_static()) { + NODE_VALIDATION_CHECK( + this, + static_cast(indices_pshape[indices_pshape.rank().get_length() - 1].get_length() + + m_batch_dims) <= data_pshape.rank().get_length(), + "Length of a tuple with indices must not exceed a rank of data tensor " + "excluding " + "batch dimensions."); + } + } + + // set output shape + set_output_size(1); + if (data_pshape.rank().is_static() && indices_pshape.rank().is_static() && + indices_pshape[indices_pshape.rank().get_length() - 1].is_static()) { + auto indices_tuple_length = indices_pshape[indices_pshape.rank().get_length() - 1].get_length(); + int64_t slice_length = data_pshape.rank().get_length() - indices_tuple_length - m_batch_dims; + int64_t output_indices_length = indices_pshape.rank().get_length() - m_batch_dims - 1; + auto output_rank = output_indices_length + slice_length; + size_t delta_output_rank = 0; + if (m_batch_dims > 0) { + delta_output_rank = 1; + } + std::vector output_shape(output_rank + delta_output_rank); + if (m_batch_dims > 0) { + output_shape[0] = 1; + for (size_t dim = 0; dim < m_batch_dims; dim++) { + if (data_pshape[dim].is_static()) { + output_shape[0] *= data_pshape[dim].get_length(); + } else if (indices_pshape[dim].is_static()) { + output_shape[0] *= indices_pshape[dim].get_length(); + } else { + output_shape[0] = Dimension::dynamic(); + break; + } + } + } + for (int64_t dim = 0; dim < output_indices_length; dim++) { + output_shape[dim + delta_output_rank] = indices_pshape[dim + m_batch_dims]; + } + for (int64_t dim = 0; dim < slice_length; dim++) { + output_shape[output_indices_length + dim + delta_output_rank] = + data_pshape[m_batch_dims + indices_tuple_length + dim]; + } + set_output_type(0, data_type, ov::PartialShape(output_shape)); + } else { + set_output_type(0, data_type, ov::PartialShape::dynamic()); + } +} + From 212f72e50c6c0422f197761f4e57744e851ff4b4 Mon Sep 17 00:00:00 2001 From: achetver Date: Fri, 8 Oct 2021 13:42:41 +0300 Subject: [PATCH 05/29] Fix base class errors --- ngraph/core/include/openvino/op/gather_nd.hpp | 13 ++++++------- .../include/openvino/op/util/gather_nd_base.hpp | 2 +- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/ngraph/core/include/openvino/op/gather_nd.hpp b/ngraph/core/include/openvino/op/gather_nd.hpp index f1c80efc73266f..cffbabd7bbb8b5 100644 --- a/ngraph/core/include/openvino/op/gather_nd.hpp +++ b/ngraph/core/include/openvino/op/gather_nd.hpp @@ -4,16 +4,15 @@ #pragma once -#include "openvino/op/op.hpp" +#include "openvino/op/util/gather_base.hpp" namespace ov { namespace op { namespace v5 { /// \brief GatherND operation -/// -class OPENVINO_API GatherND : public Op::util::GatherNDBase { +class OPENVINO_API GatherND : public op::util::GatherNDBase { public: - OPENVINO_OP("GatherND", "opset5", op::util::GatherBase, 5); + OPENVINO_OP("GatherND", "opset5", op::util::GatherNDBase, 5); BWDCMP_RTTI_DECLARATION; GatherND() = default; @@ -38,12 +37,12 @@ class OPENVINO_API GatherND : public Op::util::GatherNDBase { }; } // namespace v5 -namespace v5 { +namespace v8 { /// \brief GatherND operation /// - class OPENVINO_API GatherND : public Op::util::GatherNDBase { + class OPENVINO_API GatherND : public op::util::GatherNDBase { public: - OPENVINO_OP("GatherND", "opset8", op::util::GatherBase, 8); + OPENVINO_OP("GatherND", "opset8", op::util::GatherNDBase, 8); BWDCMP_RTTI_DECLARATION; GatherND() = default; diff --git a/ngraph/core/include/openvino/op/util/gather_nd_base.hpp b/ngraph/core/include/openvino/op/util/gather_nd_base.hpp index bea6e526131c63..17d72a645e7d22 100644 --- a/ngraph/core/include/openvino/op/util/gather_nd_base.hpp +++ b/ngraph/core/include/openvino/op/util/gather_nd_base.hpp @@ -33,7 +33,7 @@ class OPENVINO_API GatherNDBase : public Op { } private: - size_t m_batch_dims; + size_t m_batch_dims = 0; }; }; } // namespace util From 49ce0721ba4102a98fa85abaf22b0c120ac9ec46 Mon Sep 17 00:00:00 2001 From: achetver Date: Fri, 8 Oct 2021 13:51:58 +0300 Subject: [PATCH 06/29] Add missrd header --- .../include/ngraph/op/util/gather_nd_base.hpp | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 ngraph/core/include/ngraph/op/util/gather_nd_base.hpp diff --git a/ngraph/core/include/ngraph/op/util/gather_nd_base.hpp b/ngraph/core/include/ngraph/op/util/gather_nd_base.hpp new file mode 100644 index 00000000000000..6585dac61bb401 --- /dev/null +++ b/ngraph/core/include/ngraph/op/util/gather_nd_base.hpp @@ -0,0 +1,16 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "ngraph/op/op.hpp" +#include "openvino/op/util/gather_nd_base.hpp" + +namespace ngraph { +namespace op { +namespace util { +using ov::op::util::GatherNDBase; +} // namespace util +} // namespace op +} // namespace ngraph From f5e9ea6bf64ac662447d968542a9c6db0f82505d Mon Sep 17 00:00:00 2001 From: achetver Date: Fri, 8 Oct 2021 13:57:22 +0300 Subject: [PATCH 07/29] Update base class --- ngraph/core/include/openvino/op/util/gather_nd_base.hpp | 7 +++---- ngraph/core/src/op/util/gather_nd_base.cpp | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/ngraph/core/include/openvino/op/util/gather_nd_base.hpp b/ngraph/core/include/openvino/op/util/gather_nd_base.hpp index 17d72a645e7d22..f8e81c36648198 100644 --- a/ngraph/core/include/openvino/op/util/gather_nd_base.hpp +++ b/ngraph/core/include/openvino/op/util/gather_nd_base.hpp @@ -22,11 +22,11 @@ class OPENVINO_API GatherNDBase : public Op { /// \param indices Node producing indices by which the operation gathers elements /// or slices from data /// \param batch_dims Specifies a number of batch dimensions - GatherNDBase(const Output& data, const Output& indices, const size_t batch_dims = 0); + GatherNDBase(const Output& data, + const Output& indices, + const size_t batch_dims = 0); void validate_and_infer_types() override; - bool visit_attributes(AttributeVisitor& visitor) override; - std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; size_t get_batch_dims() const { return m_batch_dims; @@ -35,7 +35,6 @@ class OPENVINO_API GatherNDBase : public Op { private: size_t m_batch_dims = 0; }; -}; } // namespace util } // namespace op } // namespace ov diff --git a/ngraph/core/src/op/util/gather_nd_base.cpp b/ngraph/core/src/op/util/gather_nd_base.cpp index afab706f59e9c0..f9770db516b3bc 100644 --- a/ngraph/core/src/op/util/gather_nd_base.cpp +++ b/ngraph/core/src/op/util/gather_nd_base.cpp @@ -2,7 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "ngraph/op/util/gather_base.hpp" +#include "ngraph/op/util/gather_nd_base.hpp" #include From f134c9505d58baffa82e81821f0a0725ab7386f5 Mon Sep 17 00:00:00 2001 From: achetver Date: Fri, 8 Oct 2021 18:41:05 +0300 Subject: [PATCH 08/29] Update GatherND_8 implementation --- ngraph/core/include/openvino/op/gather_nd.hpp | 16 +--------------- .../include/openvino/op/util/gather_nd_base.hpp | 2 +- ngraph/core/src/op/gather_nd.cpp | 8 +++----- ngraph/core/src/op/util/gather_nd_base.cpp | 3 --- 4 files changed, 5 insertions(+), 24 deletions(-) diff --git a/ngraph/core/include/openvino/op/gather_nd.hpp b/ngraph/core/include/openvino/op/gather_nd.hpp index cffbabd7bbb8b5..a6acab7297bd0e 100644 --- a/ngraph/core/include/openvino/op/gather_nd.hpp +++ b/ngraph/core/include/openvino/op/gather_nd.hpp @@ -4,7 +4,7 @@ #pragma once -#include "openvino/op/util/gather_base.hpp" +#include "openvino/op/util/gather_nd_base.hpp" namespace ov { namespace op { @@ -27,13 +27,6 @@ class OPENVINO_API GatherND : public op::util::GatherNDBase { void validate_and_infer_types() override; bool visit_attributes(AttributeVisitor& visitor) override; std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; - - size_t get_batch_dims() const { - return m_batch_dims; - } - -private: - size_t m_batch_dims; }; } // namespace v5 @@ -57,13 +50,6 @@ namespace v8 { void validate_and_infer_types() override; bool visit_attributes(AttributeVisitor& visitor) override; std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; - - size_t get_batch_dims() const { - return m_batch_dims; - } - - private: - size_t m_batch_dims; }; } // namespace v8 } // namespace op diff --git a/ngraph/core/include/openvino/op/util/gather_nd_base.hpp b/ngraph/core/include/openvino/op/util/gather_nd_base.hpp index f8e81c36648198..a1e67eb8857d4a 100644 --- a/ngraph/core/include/openvino/op/util/gather_nd_base.hpp +++ b/ngraph/core/include/openvino/op/util/gather_nd_base.hpp @@ -32,7 +32,7 @@ class OPENVINO_API GatherNDBase : public Op { return m_batch_dims; } -private: +protected: size_t m_batch_dims = 0; }; } // namespace util diff --git a/ngraph/core/src/op/gather_nd.cpp b/ngraph/core/src/op/gather_nd.cpp index d084279e2ce627..125d97d27f89bc 100644 --- a/ngraph/core/src/op/gather_nd.cpp +++ b/ngraph/core/src/op/gather_nd.cpp @@ -15,8 +15,7 @@ using namespace ngraph; BWDCMP_RTTI_DEFINITION(op::v5::GatherND); op::v5::GatherND::GatherND(const Output& data, const Output& indices, const size_t batch_dims) - : Op({data, indices}), - m_batch_dims(batch_dims) { + : GatherNDBase(data, indices, batch_dims) { constructor_validate_and_infer_types(); } @@ -127,8 +126,7 @@ shared_ptr op::v5::GatherND::clone_with_new_inputs(const OutputVector& new BWDCMP_RTTI_DEFINITION(op::v8::GatherND); op::v8::GatherND::GatherND(const Output& data, const Output& indices, const size_t batch_dims) - : Op({data, indices}), - m_batch_dims(batch_dims) { + : GatherNDBase(data, indices, batch_dims) { constructor_validate_and_infer_types(); } @@ -233,4 +231,4 @@ shared_ptr op::v8::GatherND::clone_with_new_inputs(const OutputVector& new NGRAPH_OP_SCOPE(v8_GatherND_clone_with_new_inputs); check_new_args_count(this, new_args); return make_shared(new_args.at(0), new_args.at(1), m_batch_dims); -} \ No newline at end of file +} diff --git a/ngraph/core/src/op/util/gather_nd_base.cpp b/ngraph/core/src/op/util/gather_nd_base.cpp index f9770db516b3bc..d09b4a5a1bc34c 100644 --- a/ngraph/core/src/op/util/gather_nd_base.cpp +++ b/ngraph/core/src/op/util/gather_nd_base.cpp @@ -6,17 +6,14 @@ #include -#include "itt.hpp" #include "ngraph/op/concat.hpp" #include "ngraph/op/constant.hpp" #include "ngraph/op/squeeze.hpp" #include "ngraph/runtime/host_tensor.hpp" -#include "ngraph/runtime/reference/gather_nd.hpp" #include "ngraph/shape.hpp" using namespace std; - BWDCMP_RTTI_DEFINITION(ov::op::util::GatherNDBase); ov::op::util::GatherNDBase::GatherNDBase(const Output& data, const Output& indices, const size_t batch_dims) From 58ac4cfd7efccf8aa2dda75b1ac59ce8857d723c Mon Sep 17 00:00:00 2001 From: achetver Date: Mon, 11 Oct 2021 08:31:52 +0300 Subject: [PATCH 09/29] Fix codestyle --- ngraph/core/include/openvino/op/gather_nd.hpp | 4 +- .../openvino/op/util/gather_nd_base.hpp | 6 +- ngraph/core/src/op/gather_nd.cpp | 16 ++-- ngraph/core/src/op/util/gather_nd_base.cpp | 96 +------------------ 4 files changed, 13 insertions(+), 109 deletions(-) diff --git a/ngraph/core/include/openvino/op/gather_nd.hpp b/ngraph/core/include/openvino/op/gather_nd.hpp index a6acab7297bd0e..19f0207d616348 100644 --- a/ngraph/core/include/openvino/op/gather_nd.hpp +++ b/ngraph/core/include/openvino/op/gather_nd.hpp @@ -33,8 +33,8 @@ class OPENVINO_API GatherND : public op::util::GatherNDBase { namespace v8 { /// \brief GatherND operation /// - class OPENVINO_API GatherND : public op::util::GatherNDBase { - public: +class OPENVINO_API GatherND : public op::util::GatherNDBase { +public: OPENVINO_OP("GatherND", "opset8", op::util::GatherNDBase, 8); BWDCMP_RTTI_DECLARATION; GatherND() = default; diff --git a/ngraph/core/include/openvino/op/util/gather_nd_base.hpp b/ngraph/core/include/openvino/op/util/gather_nd_base.hpp index a1e67eb8857d4a..44b923954b64d4 100644 --- a/ngraph/core/include/openvino/op/util/gather_nd_base.hpp +++ b/ngraph/core/include/openvino/op/util/gather_nd_base.hpp @@ -22,11 +22,7 @@ class OPENVINO_API GatherNDBase : public Op { /// \param indices Node producing indices by which the operation gathers elements /// or slices from data /// \param batch_dims Specifies a number of batch dimensions - GatherNDBase(const Output& data, - const Output& indices, - const size_t batch_dims = 0); - - void validate_and_infer_types() override; + GatherNDBase(const Output& data, const Output& indices, const size_t batch_dims = 0); size_t get_batch_dims() const { return m_batch_dims; diff --git a/ngraph/core/src/op/gather_nd.cpp b/ngraph/core/src/op/gather_nd.cpp index 125d97d27f89bc..d92f84719cd736 100644 --- a/ngraph/core/src/op/gather_nd.cpp +++ b/ngraph/core/src/op/gather_nd.cpp @@ -126,7 +126,7 @@ shared_ptr op::v5::GatherND::clone_with_new_inputs(const OutputVector& new BWDCMP_RTTI_DEFINITION(op::v8::GatherND); op::v8::GatherND::GatherND(const Output& data, const Output& indices, const size_t batch_dims) - : GatherNDBase(data, indices, batch_dims) { + : GatherNDBase(data, indices, batch_dims) { constructor_validate_and_infer_types(); } @@ -173,12 +173,12 @@ void op::v8::GatherND::validate_and_infer_types() { if (indices_pshape[indices_pshape.rank().get_length() - 1].is_static()) { NODE_VALIDATION_CHECK( - this, - static_cast(indices_pshape[indices_pshape.rank().get_length() - 1].get_length() + - m_batch_dims) <= data_pshape.rank().get_length(), - "Length of a tuple with indices must not exceed a rank of data tensor " - "excluding " - "batch dimensions."); + this, + static_cast(indices_pshape[indices_pshape.rank().get_length() - 1].get_length() + + m_batch_dims) <= data_pshape.rank().get_length(), + "Length of a tuple with indices must not exceed a rank of data tensor " + "excluding " + "batch dimensions."); } } @@ -213,7 +213,7 @@ void op::v8::GatherND::validate_and_infer_types() { } for (int64_t dim = 0; dim < slice_length; dim++) { output_shape[output_indices_length + dim + delta_output_rank] = - data_pshape[m_batch_dims + indices_tuple_length + dim]; + data_pshape[m_batch_dims + indices_tuple_length + dim]; } set_output_type(0, data_type, ov::PartialShape(output_shape)); } else { diff --git a/ngraph/core/src/op/util/gather_nd_base.cpp b/ngraph/core/src/op/util/gather_nd_base.cpp index d09b4a5a1bc34c..f00d0510d1acbc 100644 --- a/ngraph/core/src/op/util/gather_nd_base.cpp +++ b/ngraph/core/src/op/util/gather_nd_base.cpp @@ -17,99 +17,7 @@ using namespace std; BWDCMP_RTTI_DEFINITION(ov::op::util::GatherNDBase); ov::op::util::GatherNDBase::GatherNDBase(const Output& data, const Output& indices, const size_t batch_dims) - : Op({data, indices}), - m_batch_dims(batch_dims) { + : Op({data, indices}), + m_batch_dims(batch_dims) { constructor_validate_and_infer_types(); } - -void ov::op::util::GatherNDBase::validate_and_infer_types() { - NGRAPH_OP_SCOPE(util_GatherNDBase_validate_and_infer_types); - // check types of input tensors - const auto& data_type = get_input_element_type(0); - const auto& indices_type = get_input_element_type(1); - - NODE_VALIDATION_CHECK(this, - indices_type.is_integral_number(), - "The indices type is expected to be an integer type. Got: ", - indices_type); - - // check ranks of input tensors - const auto& data_pshape = get_input_partial_shape(0); - const auto& indices_pshape = get_input_partial_shape(1); - - if (data_pshape.rank().is_static()) { - NODE_VALIDATION_CHECK(this, data_pshape.rank().get_length() > 0, "Data rank must be at least 1."); - - NODE_VALIDATION_CHECK(this, - data_pshape.rank().get_length() > static_cast(m_batch_dims), - "Number of batch dimensions must not exceed a rank of data."); - } - - if (indices_pshape.rank().is_static()) { - NODE_VALIDATION_CHECK(this, indices_pshape.rank().get_length() > 0, "Indices rank must be at least 1."); - - NODE_VALIDATION_CHECK(this, - indices_pshape.rank().get_length() > static_cast(m_batch_dims), - "Number of batch dimensions must not exceed a rank of indices."); - } - - if (data_pshape.rank().is_static() && indices_pshape.rank().is_static()) { - // check that batch dimensions of data and indices are the same - for (size_t batch_dim = 0; batch_dim < m_batch_dims; batch_dim++) { - if (data_pshape[batch_dim].is_static() && indices_pshape[batch_dim].is_static()) { - NODE_VALIDATION_CHECK(this, - data_pshape[batch_dim].get_length() == indices_pshape[batch_dim].get_length(), - "Batch dimensions of data and indices must be the same."); - } - } - - if (indices_pshape[indices_pshape.rank().get_length() - 1].is_static()) { - NODE_VALIDATION_CHECK( - this, - static_cast(indices_pshape[indices_pshape.rank().get_length() - 1].get_length() + - m_batch_dims) <= data_pshape.rank().get_length(), - "Length of a tuple with indices must not exceed a rank of data tensor " - "excluding " - "batch dimensions."); - } - } - - // set output shape - set_output_size(1); - if (data_pshape.rank().is_static() && indices_pshape.rank().is_static() && - indices_pshape[indices_pshape.rank().get_length() - 1].is_static()) { - auto indices_tuple_length = indices_pshape[indices_pshape.rank().get_length() - 1].get_length(); - int64_t slice_length = data_pshape.rank().get_length() - indices_tuple_length - m_batch_dims; - int64_t output_indices_length = indices_pshape.rank().get_length() - m_batch_dims - 1; - auto output_rank = output_indices_length + slice_length; - size_t delta_output_rank = 0; - if (m_batch_dims > 0) { - delta_output_rank = 1; - } - std::vector output_shape(output_rank + delta_output_rank); - if (m_batch_dims > 0) { - output_shape[0] = 1; - for (size_t dim = 0; dim < m_batch_dims; dim++) { - if (data_pshape[dim].is_static()) { - output_shape[0] *= data_pshape[dim].get_length(); - } else if (indices_pshape[dim].is_static()) { - output_shape[0] *= indices_pshape[dim].get_length(); - } else { - output_shape[0] = Dimension::dynamic(); - break; - } - } - } - for (int64_t dim = 0; dim < output_indices_length; dim++) { - output_shape[dim + delta_output_rank] = indices_pshape[dim + m_batch_dims]; - } - for (int64_t dim = 0; dim < slice_length; dim++) { - output_shape[output_indices_length + dim + delta_output_rank] = - data_pshape[m_batch_dims + indices_tuple_length + dim]; - } - set_output_type(0, data_type, ov::PartialShape(output_shape)); - } else { - set_output_type(0, data_type, ov::PartialShape::dynamic()); - } -} - From 214095db0f67d8a1d43ff8b5b8fb21a7456716a7 Mon Sep 17 00:00:00 2001 From: achetver Date: Mon, 11 Oct 2021 14:00:18 +0300 Subject: [PATCH 10/29] Fix wrong rank --- ngraph/core/src/op/gather_nd.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ngraph/core/src/op/gather_nd.cpp b/ngraph/core/src/op/gather_nd.cpp index d92f84719cd736..90dd5be24a20b6 100644 --- a/ngraph/core/src/op/gather_nd.cpp +++ b/ngraph/core/src/op/gather_nd.cpp @@ -192,18 +192,18 @@ void op::v8::GatherND::validate_and_infer_types() { auto output_rank = output_indices_length + slice_length; size_t delta_output_rank = 0; if (m_batch_dims > 0) { - delta_output_rank = 1; + delta_output_rank = m_batch_dims; } std::vector output_shape(output_rank + delta_output_rank); if (m_batch_dims > 0) { - output_shape[0] = 1; for (size_t dim = 0; dim < m_batch_dims; dim++) { + output_shape[dim] = 1; if (data_pshape[dim].is_static()) { - output_shape[0] *= data_pshape[dim].get_length(); + output_shape[dim] = data_pshape[dim].get_length(); } else if (indices_pshape[dim].is_static()) { - output_shape[0] *= indices_pshape[dim].get_length(); + output_shape[dim] = indices_pshape[dim].get_length(); } else { - output_shape[0] = Dimension::dynamic(); + output_shape[dim] = Dimension::dynamic(); break; } } From 115241979d5d16c2257363d88bc6463e8e6cf890 Mon Sep 17 00:00:00 2001 From: achetver Date: Mon, 11 Oct 2021 17:12:36 +0300 Subject: [PATCH 11/29] Implement tests for gatherND_8 shape inference function --- ngraph/test/type_prop/gather_nd.cpp | 133 ++++++++++++++++++++++++++++ 1 file changed, 133 insertions(+) diff --git a/ngraph/test/type_prop/gather_nd.cpp b/ngraph/test/type_prop/gather_nd.cpp index e05c23a52b3dc8..27f551e1ff8b1a 100644 --- a/ngraph/test/type_prop/gather_nd.cpp +++ b/ngraph/test/type_prop/gather_nd.cpp @@ -316,3 +316,136 @@ TEST(type_prop, gather_nd_fail_indices_element_type) { FAIL() << "Deduced type check failed for unexpected reason"; } } + + +// ------------------------------ V5 ------------------------------ + +TEST(type_prop, gather_nd_8_slices_from_4d_batch_dims0) { + Shape params_shape{2, 3, 11, 12}; + Shape indices_shape{2, 3, 2}; + Shape out_shape{2, 3, 11, 12}; + auto P = make_shared(element::f32, params_shape); + auto I = make_shared(element::i32, indices_shape); + auto G5 = make_shared(P, I, 0); + ASSERT_EQ(G5->get_element_type(), element::f32); + ASSERT_EQ(G5->get_shape(), out_shape); +} + +TEST(type_prop, gather_nd_8_scalars_from_4d_batch_dims2) { + Shape params_shape{2, 3, 11, 12}; + Shape indices_shape{2, 3, 2}; + Shape out_shape{2, 3}; + auto P = make_shared(element::f32, params_shape); + auto I = make_shared(element::i32, indices_shape); + auto G5 = make_shared(P, I, 2); + ASSERT_EQ(G5->get_element_type(), element::f32); + ASSERT_EQ(G5->get_shape(), out_shape); +} + +TEST(type_prop, gather_nd_8_slices_from_5d_batch_dims2) { + Shape params_shape{7, 5, 11, 12, 32}; + Shape indices_shape{7, 5, 3, 1}; + Shape out_shape{7, 5, 3, 12, 32}; + auto P = make_shared(element::f32, params_shape); + auto I = make_shared(element::i32, indices_shape); + auto G5 = make_shared(P, I, 2); + ASSERT_EQ(G5->get_element_type(), element::f32); + ASSERT_EQ(G5->get_shape(), out_shape); +} + +TEST(type_prop, gather_nd_8_batch_dim2_with_dyn_dim) { + PartialShape params_shape{7, Dimension::dynamic(), 11, 12, 32}; + Shape indices_shape{7, 5, 3, 1}; + Shape out_shape{7, 5, 3, 12, 32}; + auto P = make_shared(element::f32, params_shape); + auto I = make_shared(element::i32, indices_shape); + auto G5 = make_shared(P, I, 2); + ASSERT_EQ(G5->get_element_type(), element::f32); + ASSERT_EQ(G5->get_shape(), out_shape); +} + +TEST(type_prop, gather_nd_8_batch_dim2_with_dyn_dim2) { + PartialShape params_shape{7, Dimension::dynamic(), Dimension::dynamic(), 12, 32}; + Shape indices_shape{7, 5, 3, 1}; + Shape out_shape{7, 5, 3, 12, 32}; + auto P = make_shared(element::f32, params_shape); + auto I = make_shared(element::i32, indices_shape); + auto G5 = make_shared(P, I, 2); + ASSERT_EQ(G5->get_element_type(), element::f32); + ASSERT_EQ(G5->get_shape(), out_shape); +} + +TEST(type_prop, gather_nd_8_batch_dim2_with_dyn_dim3) { + PartialShape params_shape{7, Dimension::dynamic(), Dimension::dynamic(), 12, Dimension::dynamic()}; + Shape indices_shape{7, 5, 3, 1}; + PartialShape out_shape{7, 5, 3, 12, Dimension::dynamic()}; + auto P = make_shared(element::f32, params_shape); + auto I = make_shared(element::i32, indices_shape); + auto G5 = make_shared(P, I, 2); + ASSERT_EQ(G5->get_element_type(), element::f32); + ASSERT_TRUE(G5->get_output_partial_shape(0).same_scheme(out_shape)); +} + +TEST(type_prop, gather_nd_8_batch_dim0_with_dyn_ind_dim) { + PartialShape params_shape{7, Dimension::dynamic(), Dimension::dynamic(), 12, Dimension::dynamic()}; + PartialShape indices_shape{7, 5, 3, Dimension::dynamic()}; + auto P = make_shared(element::f32, params_shape); + auto I = make_shared(element::i32, indices_shape); + auto G5 = make_shared(P, I, 0); + ASSERT_EQ(G5->get_element_type(), element::f32); + ASSERT_TRUE(G5->get_output_partial_shape(0).same_scheme(PartialShape::dynamic())); +} + +TEST(type_prop, gather_nd_8_fail_batch_dims_greater_indices_rank) { + Shape params_shape{2, 3, 4, 5}; + Shape indices_shape{2, 1}; + auto P = make_shared(element::f32, params_shape); + auto I = make_shared(element::i32, indices_shape); + + try { + auto G5 = make_shared(P, I, 3); + // Should have thrown, so fail if it didn't + FAIL() << "Incorrect indices rank"; + } catch (const NodeValidationFailure& error) { + EXPECT_HAS_SUBSTRING(error.what(), + std::string("Number of batch dimensions must not exceed a rank of indices.")); + } catch (...) { + FAIL() << "Deduced type check failed for unexpected reason"; + } +} + +TEST(type_prop, gather_nd_8_fail_unequal_batch_dims) { + Shape params_shape{2, 3, 4, 5}; + Shape indices_shape{2, 1, 4}; + auto P = make_shared(element::f32, params_shape); + auto I = make_shared(element::i32, indices_shape); + + try { + auto G5 = make_shared(P, I, 2); + // Should have thrown, so fail if it didn't + FAIL() << "Incorrect indices rank"; + } catch (const NodeValidationFailure& error) { + EXPECT_HAS_SUBSTRING(error.what(), std::string("Batch dimensions of data and indices must be the same.")); + } catch (...) { + FAIL() << "Deduced type check failed for unexpected reason"; + } +} + +TEST(type_prop, gather_nd_8_fail_indices_tuple_greater_data_rank_batch_dims2) { + Shape params_shape{2, 1, 4, 5}; + Shape indices_shape{2, 1, 5, 3}; + auto P = make_shared(element::f32, params_shape); + auto I = make_shared(element::i32, indices_shape); + + try { + auto G5 = make_shared(P, I, 2); + // Should have thrown, so fail if it didn't + FAIL() << "Incorrect indices rank"; + } catch (const NodeValidationFailure& error) { + EXPECT_HAS_SUBSTRING(error.what(), + std::string("Length of a tuple with indices must not exceed a rank of " + "data tensor excluding batch dimensions.")); + } catch (...) { + FAIL() << "Deduced type check failed for unexpected reason"; + } +} From 7380d4489380ea21a1ac6d71fbb85cc89e7d97d3 Mon Sep 17 00:00:00 2001 From: Chetverikov Date: Mon, 11 Oct 2021 17:47:50 +0300 Subject: [PATCH 12/29] fix codestyle --- ngraph/test/type_prop/gather_nd.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ngraph/test/type_prop/gather_nd.cpp b/ngraph/test/type_prop/gather_nd.cpp index 27f551e1ff8b1a..ab3ab01d2b8c25 100644 --- a/ngraph/test/type_prop/gather_nd.cpp +++ b/ngraph/test/type_prop/gather_nd.cpp @@ -317,8 +317,7 @@ TEST(type_prop, gather_nd_fail_indices_element_type) { } } - -// ------------------------------ V5 ------------------------------ +// ------------------------------ V8 ------------------------------ TEST(type_prop, gather_nd_8_slices_from_4d_batch_dims0) { Shape params_shape{2, 3, 11, 12}; From cd85be1957b8a0fcb7d44f6ffb910609b0ae2a03 Mon Sep 17 00:00:00 2001 From: Chetverikov Date: Mon, 11 Oct 2021 17:51:47 +0300 Subject: [PATCH 13/29] Add limitation to doc --- docs/MO_DG/prepare_model/Supported_Frameworks_Layers.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/MO_DG/prepare_model/Supported_Frameworks_Layers.md b/docs/MO_DG/prepare_model/Supported_Frameworks_Layers.md index 30173058227bfd..900802a2def239 100644 --- a/docs/MO_DG/prepare_model/Supported_Frameworks_Layers.md +++ b/docs/MO_DG/prepare_model/Supported_Frameworks_Layers.md @@ -530,7 +530,7 @@ Standard ONNX\* operators: | GRU | No | | Gather | No | | GatherElements | Doesn't work with negative indices | -| GatherND | No | +| GatherND | Doesn't work with negative indices | | GatherTree | No | | Gemm | No | | GlobalAveragePool | No | From 6a0cbef61a79e755e85bce866b0ae9d53068f561 Mon Sep 17 00:00:00 2001 From: Chetverikov Date: Fri, 15 Oct 2021 13:37:56 +0300 Subject: [PATCH 14/29] Siplyfy check in shape inference --- model-optimizer/extensions/ops/gathernd.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/model-optimizer/extensions/ops/gathernd.py b/model-optimizer/extensions/ops/gathernd.py index 126960ae059261..0cdab7aed24e0a 100644 --- a/model-optimizer/extensions/ops/gathernd.py +++ b/model-optimizer/extensions/ops/gathernd.py @@ -74,10 +74,7 @@ def infer(node: Node): for i in range(batch_dims): batch_dims_size *= indices_shape[i] - if indices_shape[-1] == len(data_shape) - batch_dims: - output_shape = batch + list(indices_shape)[batch_dims:-1] - else: - output_shape = batch + list(indices_shape)[batch_dims:-1] + slice_shape + output_shape = batch + list(indices_shape)[batch_dims:-1] + slice_shape node.out_port(0).data.set_shape(output_shape) # compute output value if all input values are defined From afd02f8733af88cdfd167bba92fec59fd1277fbb Mon Sep 17 00:00:00 2001 From: Chetverikov Date: Fri, 15 Oct 2021 13:39:57 +0300 Subject: [PATCH 15/29] Add more test cases --- .../extensions/ops/gathernd_test.py | 117 +++++++++++++++--- 1 file changed, 102 insertions(+), 15 deletions(-) diff --git a/model-optimizer/unit_tests/extensions/ops/gathernd_test.py b/model-optimizer/unit_tests/extensions/ops/gathernd_test.py index f43386247c55be..013eb14a46b206 100644 --- a/model-optimizer/unit_tests/extensions/ops/gathernd_test.py +++ b/model-optimizer/unit_tests/extensions/ops/gathernd_test.py @@ -25,17 +25,21 @@ ('gathernd_node', 'output', {'out': 0})] # test data for partial infer: gather elements -inputs1 = {'data_data': {'shape': int64_array([10, 40]), 'value': None}, +inputs = {'data_data': {'shape': int64_array([10, 40]), 'value': None}, 'indices_data': {'shape': int64_array([3, 2]), 'value': None}} # test data for partial infer: gather slices -inputs2 = {'data_data': {'shape': int64_array([10, 40, 30]), 'value': None}, +inputs1 = {'data_data': {'shape': int64_array([10, 40, 30]), 'value': None}, 'indices_data': {'shape': int64_array([3, 2]), 'value': None}} # test data for partial infer: gather slices and batch_dims=2 -inputs3 = {'data_data': {'shape': int64_array([10, 40, 4, 9]), 'value': None}, +inputs2 = {'data_data': {'shape': int64_array([10, 40, 4, 9]), 'value': None}, 'indices_data': {'shape': int64_array([10, 40, 3, 5, 1]), 'value': None}} +# test data for partial infer: gather slices and batch_dims=3 and indices.shape[-1]=len(data.shape)-batch_dims +inputs3 = {'data_data': {'shape': int64_array([1, 64, 64, 320]), 'value': None}, + 'indices_data': {'shape': int64_array([1, 64, 64, 1, 1]), 'value': None}} + # test data for constant folding: gather elements, batch_dims = 0 inputs4 = {'data_data': {'shape': int64_array([2, 2]), 'value': int64_array([[1, 2], [3, 4]])}, @@ -100,11 +104,11 @@ [[1, 1], [2, 0], [2, 0]], [[0, 0], [3, 1], [3, 1]]]])}} output8 = int64_array([[[3, 8, 6], - [10, 12, 13], - [23, 24, 22]], - [[29, 28, 32], - [36, 37, 37], - [41, 48, 48]]]) + [10, 12, 13], + [23, 24, 22]], + [[29, 28, 32], + [36, 37, 37], + [41, 48, 48]]]) # test data for partial infer: gather slices and batch_dims=2 inputs9 = {'data_data': {'shape': shape_array([dynamic_dimension_value, 40, 4, 9]), 'value': None}, @@ -126,9 +130,10 @@ class TestGatherNDUpdate(unittest.TestCase): def setUp(self): nodes_attributes['gathernd_node']['batch_dims'] = 0 + nodes_attributes['gathernd_node']['version'] = 'opset8' def test_partial_infer_gather_element(self): - graph = build_graph(nodes_attributes, edges, inputs1) + graph = build_graph(nodes_attributes, edges, inputs) gathernd_node = Node(graph, 'gathernd_node') GatherND.infer(gathernd_node) @@ -142,7 +147,7 @@ def test_partial_infer_gather_element(self): 'values do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape)) def test_partial_infer_gather_slice(self): - graph = build_graph(nodes_attributes, edges, inputs2) + graph = build_graph(nodes_attributes, edges, inputs1) gathernd_node = Node(graph, 'gathernd_node') GatherND.infer(gathernd_node) @@ -157,7 +162,7 @@ def test_partial_infer_gather_slice(self): def test_partial_infer_gather_slice_batch_dims2(self): nodes_attributes['gathernd_node']['batch_dims'] = 2 - graph = build_graph(nodes_attributes, edges, inputs3) + graph = build_graph(nodes_attributes, edges, inputs2) gathernd_node = Node(graph, 'gathernd_node') GatherND.infer(gathernd_node) @@ -170,6 +175,21 @@ def test_partial_infer_gather_slice_batch_dims2(self): self.assertTrue(np.array_equal(ref_output_shape, res_output_shape), 'values do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape)) + def test_partial_infer_gather_slice_batch_dims3(self): + nodes_attributes['gathernd_node']['batch_dims'] = 3 + graph = build_graph(nodes_attributes, edges, inputs3) + gathernd_node = Node(graph, 'gathernd_node') + GatherND.infer(gathernd_node) + + # prepare reference results + ref_output_shape = int64_array([1, 64, 64, 1]) + + # get the result + res_output_shape = graph.node['output']['shape'] + + self.assertTrue(np.array_equal(ref_output_shape, res_output_shape), + 'values do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape)) + def test_partial_infer_gather_slice_batch_dims2_dynamic(self): nodes_attributes['gathernd_node']['batch_dims'] = 2 graph = build_graph(nodes_attributes, edges, inputs9) @@ -205,7 +225,7 @@ def test_infer5(self): res_output_value = graph.node['output']['value'] self.assertTrue(np.array_equal(output5, res_output_value), - 'values do not match expected: {} and given: {}'.format(output4, res_output_value)) + 'values do not match expected: {} and given: {}'.format(output5, res_output_value)) def test_infer6(self): nodes_attributes['gathernd_node']['batch_dims'] = 1 @@ -217,7 +237,20 @@ def test_infer6(self): res_output_value = graph.node['output']['value'] self.assertTrue(np.array_equal(output6, res_output_value), - 'values do not match expected: {} and given: {}'.format(output4, res_output_value)) + 'values do not match expected: {} and given: {}'.format(output6, res_output_value)) + + def test_infer6_opset_5(self): + nodes_attributes['gathernd_node']['batch_dims'] = 1 + nodes_attributes['gathernd_node']['version'] = 'opset5' + graph = build_graph(nodes_attributes, edges, inputs6) + gathernd_node = Node(graph, 'gathernd_node') + GatherND.infer(gathernd_node) + + # get the result + res_output_value = graph.node['output']['value'] + + self.assertTrue(np.array_equal(output6, res_output_value), + 'values do not match expected: {} and given: {}'.format(output6, res_output_value)) def test_infer7(self): nodes_attributes['gathernd_node']['batch_dims'] = 2 @@ -229,7 +262,21 @@ def test_infer7(self): res_output_value = graph.node['output']['value'] self.assertTrue(np.array_equal(output7, res_output_value), - 'values do not match expected: {} and given: {}'.format(output4, res_output_value)) + 'values do not match expected: {} and given: {}'.format(output7, res_output_value)) + + def test_infer7_opset_5(self): + nodes_attributes['gathernd_node']['batch_dims'] = 2 + nodes_attributes['gathernd_node']['version'] = 'opset5' + graph = build_graph(nodes_attributes, edges, inputs7) + gathernd_node = Node(graph, 'gathernd_node') + GatherND.infer(gathernd_node) + + # get the result + res_output_value = graph.node['output']['value'] + + output = output7.reshape([6, 1]) + self.assertTrue(np.array_equal(output, res_output_value), + 'values do not match expected: {} and given: {}'.format(output, res_output_value)) def test_infer8(self): nodes_attributes['gathernd_node']['batch_dims'] = 2 @@ -241,7 +288,47 @@ def test_infer8(self): res_output_value = graph.node['output']['value'] self.assertTrue(np.array_equal(output8, res_output_value), - 'values do not match expected: {} and given: {}'.format(output4, res_output_value)) + 'values do not match expected: {} and given: {}'.format(output8, res_output_value)) + + def test_infer8_opset_5(self): + nodes_attributes['gathernd_node']['batch_dims'] = 2 + nodes_attributes['gathernd_node']['version'] = 'opset5' + graph = build_graph(nodes_attributes, edges, inputs8) + gathernd_node = Node(graph, 'gathernd_node') + GatherND.infer(gathernd_node) + + # get the result + res_output_value = graph.node['output']['value'] + + output = output8.reshape([6, 3]) + self.assertTrue(np.array_equal(output, res_output_value), + 'values do not match expected: {} and given: {}'.format(output, res_output_value)) + + def test_infer9(self): + nodes_attributes['gathernd_node']['batch_dims'] = 2 + graph = build_graph(nodes_attributes, edges, inputs8) + gathernd_node = Node(graph, 'gathernd_node') + GatherND.infer(gathernd_node) + + # get the result + res_output_value = graph.node['output']['value'] + + self.assertTrue(np.array_equal(output8, res_output_value), + 'values do not match expected: {} and given: {}'.format(output8, res_output_value)) + + def test_infer9_opset_5(self): + nodes_attributes['gathernd_node']['batch_dims'] = 2 + nodes_attributes['gathernd_node']['version'] = 'opset5' + graph = build_graph(nodes_attributes, edges, inputs8) + gathernd_node = Node(graph, 'gathernd_node') + GatherND.infer(gathernd_node) + + # get the result + res_output_value = graph.node['output']['value'] + + output = output8.reshape([6, 3]) + self.assertTrue(np.array_equal(output, res_output_value), + 'values do not match expected: {} and given: {}'.format(output, res_output_value)) def test_infer_invalid1(self): graph = build_graph(nodes_attributes, edges, inputs_inv1) From 29314ecc2e6d3a5c7be270acac7b6659ed735802 Mon Sep 17 00:00:00 2001 From: Chetverikov Date: Fri, 15 Oct 2021 13:56:06 +0300 Subject: [PATCH 16/29] Update shape inference function --- model-optimizer/extensions/ops/gathernd.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/model-optimizer/extensions/ops/gathernd.py b/model-optimizer/extensions/ops/gathernd.py index 0cdab7aed24e0a..9acd1a7a34db72 100644 --- a/model-optimizer/extensions/ops/gathernd.py +++ b/model-optimizer/extensions/ops/gathernd.py @@ -61,12 +61,21 @@ def infer(node: Node): if batch_dims > 0: if is_fully_defined(indices_shape[:batch_dims]): batch = indices_shape[:batch_dims].tolist() - if node['version'] == 'opset5': # Support old version of gather + if node['version'] == 'opset5': # Support old version of gatherND shape inference batch = [np.prod(data_shape[:batch_dims]).tolist()] else: - batch = [dynamic_dimension_value] + batch = [] + for ind in range(batch_dims): + if indices_shape[ind] != dynamic_dimension_value: + batch.append(indices_shape[ind]) + elif data_shape[ind] != dynamic_dimension_value: + batch.append(data_shape[ind]) + else: + batch.append(dynamic_dimension_value) + pass else: batch = [] + slice_shape = list(data_shape[(batch_dims + indices_shape[-1]):]) batch_dims_size = 1 From 5819af2dd0d7d4e52f8b5cbb0eb9666dcc4802df Mon Sep 17 00:00:00 2001 From: Chetverikov Date: Fri, 15 Oct 2021 13:56:53 +0300 Subject: [PATCH 17/29] Add more test cases to cover all case with dynamic input shapes --- .../extensions/ops/gathernd_test.py | 42 ++++++++++++++++++- 1 file changed, 40 insertions(+), 2 deletions(-) diff --git a/model-optimizer/unit_tests/extensions/ops/gathernd_test.py b/model-optimizer/unit_tests/extensions/ops/gathernd_test.py index 013eb14a46b206..f0ae1d0c41de9a 100644 --- a/model-optimizer/unit_tests/extensions/ops/gathernd_test.py +++ b/model-optimizer/unit_tests/extensions/ops/gathernd_test.py @@ -114,6 +114,14 @@ inputs9 = {'data_data': {'shape': shape_array([dynamic_dimension_value, 40, 4, 9]), 'value': None}, 'indices_data': {'shape': shape_array([dynamic_dimension_value, 40, 3, 5, 1]), 'value': None}} +# test data for partial infer: gather slices and batch_dims=2 +inputs10 = {'data_data': {'shape': shape_array([40, dynamic_dimension_value, 4, 9]), 'value': None}, + 'indices_data': {'shape': shape_array([40, dynamic_dimension_value, 3, 5, 1]), 'value': None}} + +# test data for partial infer: gather slices and batch_dims=2 +inputs11 = {'data_data': {'shape': shape_array([dynamic_dimension_value, 40, 4, 9]), 'value': None}, + 'indices_data': {'shape': shape_array([40, dynamic_dimension_value, 3, 5, 1]), 'value': None}} + # invalid test case with incorrect rank for indices inputs_inv1 = {'data_data': {'shape': int64_array([10, 40]), 'value': None}, 'indices_data': {'shape': int64_array([5, 3, 4]), 'value': None}} @@ -190,14 +198,44 @@ def test_partial_infer_gather_slice_batch_dims3(self): self.assertTrue(np.array_equal(ref_output_shape, res_output_shape), 'values do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape)) - def test_partial_infer_gather_slice_batch_dims2_dynamic(self): + def test_partial_infer_gather_slice_batch_dims2_dynamic1(self): nodes_attributes['gathernd_node']['batch_dims'] = 2 graph = build_graph(nodes_attributes, edges, inputs9) gathernd_node = Node(graph, 'gathernd_node') GatherND.infer(gathernd_node) # prepare reference results - ref_output_shape = shape_array([dynamic_dimension_value, 3, 5, 9]) + ref_output_shape = shape_array([dynamic_dimension_value, 40, 3, 5, 9]) + + # get the result + res_output_shape = graph.node['output']['shape'] + + self.assertTrue(strict_compare_tensors(ref_output_shape, res_output_shape), + 'values do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape)) + + def test_partial_infer_gather_slice_batch_dims2_dynamic2(self): + nodes_attributes['gathernd_node']['batch_dims'] = 2 + graph = build_graph(nodes_attributes, edges, inputs10) + gathernd_node = Node(graph, 'gathernd_node') + GatherND.infer(gathernd_node) + + # prepare reference results + ref_output_shape = shape_array([40, dynamic_dimension_value, 3, 5, 9]) + + # get the result + res_output_shape = graph.node['output']['shape'] + + self.assertTrue(strict_compare_tensors(ref_output_shape, res_output_shape), + 'values do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape)) + + def test_partial_infer_gather_slice_batch_dims2_dynamic3(self): + nodes_attributes['gathernd_node']['batch_dims'] = 2 + graph = build_graph(nodes_attributes, edges, inputs11) + gathernd_node = Node(graph, 'gathernd_node') + GatherND.infer(gathernd_node) + + # prepare reference results + ref_output_shape = shape_array([40, 40, 3, 5, 9]) # get the result res_output_shape = graph.node['output']['shape'] From c689aa6c32db5db45511082ca41a44f195ade00e Mon Sep 17 00:00:00 2001 From: Chetverikov Date: Mon, 18 Oct 2021 13:29:20 +0300 Subject: [PATCH 18/29] Update shape inference function --- model-optimizer/extensions/ops/gathernd.py | 36 +++++++++++++--------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/model-optimizer/extensions/ops/gathernd.py b/model-optimizer/extensions/ops/gathernd.py index 9acd1a7a34db72..fe0f903d1df106 100644 --- a/model-optimizer/extensions/ops/gathernd.py +++ b/model-optimizer/extensions/ops/gathernd.py @@ -56,24 +56,32 @@ def infer(node: Node): assert len(indices_shape) > 0, "Indices must not be a scalar" assert (batch_dims + indices_shape[-1]) <= len(data_shape), \ "Length of a tuple with indices must not exceed a rank of data tensor excluding batch dimensions" + assert node['version'] in ['opset5', 'opset8'], 'Unsupported version of GatherND operation: {}, operation ' \ + 'name : {}'.format(node['version'], node.soft_get('name')) # compute output shape if batch_dims > 0: - if is_fully_defined(indices_shape[:batch_dims]): - batch = indices_shape[:batch_dims].tolist() - if node['version'] == 'opset5': # Support old version of gatherND shape inference + if node['version'] == 'opset5': # Support old version of gatherND shape inference + if is_fully_defined(data_shape[:batch_dims]): batch = [np.prod(data_shape[:batch_dims]).tolist()] - else: - batch = [] - for ind in range(batch_dims): - if indices_shape[ind] != dynamic_dimension_value: - batch.append(indices_shape[ind]) - elif data_shape[ind] != dynamic_dimension_value: - batch.append(data_shape[ind]) - else: - batch.append(dynamic_dimension_value) - pass - else: + else: + batch = [dynamic_dimension_value] + elif node['version'] == 'opset8': + if is_fully_defined(indices_shape[:batch_dims]): + batch = indices_shape[:batch_dims].tolist() + elif is_fully_defined(data_shape[:batch_dims]): + batch = data_shape[:batch_dims].tolist() + else: + batch = [] + for ind in range(batch_dims): + if indices_shape[ind] != dynamic_dimension_value: + batch.append(indices_shape[ind]) + elif data_shape[ind] != dynamic_dimension_value: + batch.append(data_shape[ind]) + else: + batch.append(dynamic_dimension_value) + + else: # if batch_dims == 0 batch = [] slice_shape = list(data_shape[(batch_dims + indices_shape[-1]):]) From b95d1eba88a7c9b6b1b4924c557d5bd3b480bc32 Mon Sep 17 00:00:00 2001 From: Chetverikov Date: Mon, 18 Oct 2021 14:16:25 +0300 Subject: [PATCH 19/29] Refactor tests --- .../extensions/ops/gathernd_test.py | 180 ++++++++++++------ 1 file changed, 125 insertions(+), 55 deletions(-) diff --git a/model-optimizer/unit_tests/extensions/ops/gathernd_test.py b/model-optimizer/unit_tests/extensions/ops/gathernd_test.py index f0ae1d0c41de9a..8ccaad1a3abfb1 100644 --- a/model-optimizer/unit_tests/extensions/ops/gathernd_test.py +++ b/model-optimizer/unit_tests/extensions/ops/gathernd_test.py @@ -86,7 +86,7 @@ [[[0]], [[2]], [[2]]]])}} -output7 = int64_array([[[2], [5], [11]], [[13], [19], [23]]]) +output7 = int64_array([[2], [5], [11], [13], [19], [23]]) # test data for constant folding: gather elements, batch_dims = 2 inputs8 = {'data_data': {'shape': int64_array([2, 3, 4, 2]), @@ -103,12 +103,12 @@ [[[2, 0], [1, 1], [3, 1]], [[1, 1], [2, 0], [2, 0]], [[0, 0], [3, 1], [3, 1]]]])}} -output8 = int64_array([[[3, 8, 6], - [10, 12, 13], - [23, 24, 22]], - [[29, 28, 32], - [36, 37, 37], - [41, 48, 48]]]) +output8 = int64_array([[3, 8, 6], + [10, 12, 13], + [23, 24, 22], + [29, 28, 32], + [36, 37, 37], + [41, 48, 48]]) # test data for partial infer: gather slices and batch_dims=2 inputs9 = {'data_data': {'shape': shape_array([dynamic_dimension_value, 40, 4, 9]), 'value': None}, @@ -135,10 +135,10 @@ 'indices_data': {'shape': int64_array([10, 40, 4]), 'value': None}} -class TestGatherNDUpdate(unittest.TestCase): +class TestGatherND_5(unittest.TestCase): def setUp(self): nodes_attributes['gathernd_node']['batch_dims'] = 0 - nodes_attributes['gathernd_node']['version'] = 'opset8' + nodes_attributes['gathernd_node']['version'] = 'opset5' def test_partial_infer_gather_element(self): graph = build_graph(nodes_attributes, edges, inputs) @@ -175,7 +175,7 @@ def test_partial_infer_gather_slice_batch_dims2(self): GatherND.infer(gathernd_node) # prepare reference results - ref_output_shape = int64_array([10, 40, 3, 5, 9]) + ref_output_shape = int64_array([400, 3, 5, 9]) # get the result res_output_shape = graph.node['output']['shape'] @@ -190,7 +190,7 @@ def test_partial_infer_gather_slice_batch_dims3(self): GatherND.infer(gathernd_node) # prepare reference results - ref_output_shape = int64_array([1, 64, 64, 1]) + ref_output_shape = int64_array([4096, 1]) # get the result res_output_shape = graph.node['output']['shape'] @@ -205,7 +205,7 @@ def test_partial_infer_gather_slice_batch_dims2_dynamic1(self): GatherND.infer(gathernd_node) # prepare reference results - ref_output_shape = shape_array([dynamic_dimension_value, 40, 3, 5, 9]) + ref_output_shape = shape_array([dynamic_dimension_value, 3, 5, 9]) # get the result res_output_shape = graph.node['output']['shape'] @@ -220,7 +220,7 @@ def test_partial_infer_gather_slice_batch_dims2_dynamic2(self): GatherND.infer(gathernd_node) # prepare reference results - ref_output_shape = shape_array([40, dynamic_dimension_value, 3, 5, 9]) + ref_output_shape = shape_array([dynamic_dimension_value, 3, 5, 9]) # get the result res_output_shape = graph.node['output']['shape'] @@ -235,7 +235,7 @@ def test_partial_infer_gather_slice_batch_dims2_dynamic3(self): GatherND.infer(gathernd_node) # prepare reference results - ref_output_shape = shape_array([40, 40, 3, 5, 9]) + ref_output_shape = shape_array([dynamic_dimension_value, 3, 5, 9]) # get the result res_output_shape = graph.node['output']['shape'] @@ -277,19 +277,6 @@ def test_infer6(self): self.assertTrue(np.array_equal(output6, res_output_value), 'values do not match expected: {} and given: {}'.format(output6, res_output_value)) - def test_infer6_opset_5(self): - nodes_attributes['gathernd_node']['batch_dims'] = 1 - nodes_attributes['gathernd_node']['version'] = 'opset5' - graph = build_graph(nodes_attributes, edges, inputs6) - gathernd_node = Node(graph, 'gathernd_node') - GatherND.infer(gathernd_node) - - # get the result - res_output_value = graph.node['output']['value'] - - self.assertTrue(np.array_equal(output6, res_output_value), - 'values do not match expected: {} and given: {}'.format(output6, res_output_value)) - def test_infer7(self): nodes_attributes['gathernd_node']['batch_dims'] = 2 graph = build_graph(nodes_attributes, edges, inputs7) @@ -299,19 +286,6 @@ def test_infer7(self): # get the result res_output_value = graph.node['output']['value'] - self.assertTrue(np.array_equal(output7, res_output_value), - 'values do not match expected: {} and given: {}'.format(output7, res_output_value)) - - def test_infer7_opset_5(self): - nodes_attributes['gathernd_node']['batch_dims'] = 2 - nodes_attributes['gathernd_node']['version'] = 'opset5' - graph = build_graph(nodes_attributes, edges, inputs7) - gathernd_node = Node(graph, 'gathernd_node') - GatherND.infer(gathernd_node) - - # get the result - res_output_value = graph.node['output']['value'] - output = output7.reshape([6, 1]) self.assertTrue(np.array_equal(output, res_output_value), 'values do not match expected: {} and given: {}'.format(output, res_output_value)) @@ -328,20 +302,6 @@ def test_infer8(self): self.assertTrue(np.array_equal(output8, res_output_value), 'values do not match expected: {} and given: {}'.format(output8, res_output_value)) - def test_infer8_opset_5(self): - nodes_attributes['gathernd_node']['batch_dims'] = 2 - nodes_attributes['gathernd_node']['version'] = 'opset5' - graph = build_graph(nodes_attributes, edges, inputs8) - gathernd_node = Node(graph, 'gathernd_node') - GatherND.infer(gathernd_node) - - # get the result - res_output_value = graph.node['output']['value'] - - output = output8.reshape([6, 3]) - self.assertTrue(np.array_equal(output, res_output_value), - 'values do not match expected: {} and given: {}'.format(output, res_output_value)) - def test_infer9(self): nodes_attributes['gathernd_node']['batch_dims'] = 2 graph = build_graph(nodes_attributes, edges, inputs8) @@ -356,7 +316,6 @@ def test_infer9(self): def test_infer9_opset_5(self): nodes_attributes['gathernd_node']['batch_dims'] = 2 - nodes_attributes['gathernd_node']['version'] = 'opset5' graph = build_graph(nodes_attributes, edges, inputs8) gathernd_node = Node(graph, 'gathernd_node') GatherND.infer(gathernd_node) @@ -384,3 +343,114 @@ def test_infer_invalid3(self): graph = build_graph(nodes_attributes, edges, inputs_inv3) gathernd_node = Node(graph, 'gathernd_node') self.assertRaises(AssertionError, GatherND.infer, gathernd_node) + + + def test_partial_infer_gather_slice_batch_dims2_opset8(self): + nodes_attributes['gathernd_node']['batch_dims'] = 2 + nodes_attributes['gathernd_node']['version'] = 'opset8' + graph = build_graph(nodes_attributes, edges, inputs2) + gathernd_node = Node(graph, 'gathernd_node') + GatherND.infer(gathernd_node) + + # prepare reference results + ref_output_shape = int64_array([10, 40, 3, 5, 9]) + + # get the result + res_output_shape = graph.node['output']['shape'] + + self.assertTrue(np.array_equal(ref_output_shape, res_output_shape), + 'values do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape)) + + def test_partial_infer_gather_slice_batch_dims3_opset8(self): + nodes_attributes['gathernd_node']['batch_dims'] = 3 + nodes_attributes['gathernd_node']['version'] = 'opset8' + graph = build_graph(nodes_attributes, edges, inputs3) + gathernd_node = Node(graph, 'gathernd_node') + GatherND.infer(gathernd_node) + + # prepare reference results + ref_output_shape = int64_array([1, 64, 64, 1]) + + # get the result + res_output_shape = graph.node['output']['shape'] + + self.assertTrue(np.array_equal(ref_output_shape, res_output_shape), + 'values do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape)) + + def test_partial_infer_gather_slice_batch_dims2_dynamic1_opset8(self): + nodes_attributes['gathernd_node']['batch_dims'] = 2 + nodes_attributes['gathernd_node']['version'] = 'opset8' + graph = build_graph(nodes_attributes, edges, inputs9) + gathernd_node = Node(graph, 'gathernd_node') + GatherND.infer(gathernd_node) + + # prepare reference results + ref_output_shape = shape_array([dynamic_dimension_value, 40, 3, 5, 9]) + + # get the result + res_output_shape = graph.node['output']['shape'] + + self.assertTrue(strict_compare_tensors(ref_output_shape, res_output_shape), + 'values do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape)) + + def test_partial_infer_gather_slice_batch_dims2_dynamic2_opset8(self): + nodes_attributes['gathernd_node']['batch_dims'] = 2 + nodes_attributes['gathernd_node']['version'] = 'opset8' + graph = build_graph(nodes_attributes, edges, inputs10) + gathernd_node = Node(graph, 'gathernd_node') + GatherND.infer(gathernd_node) + + # prepare reference results + ref_output_shape = shape_array([40, dynamic_dimension_value, 3, 5, 9]) + + # get the result + res_output_shape = graph.node['output']['shape'] + + self.assertTrue(strict_compare_tensors(ref_output_shape, res_output_shape), + 'values do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape)) + + def test_partial_infer_gather_slice_batch_dims2_dynamic3_opset8(self): + nodes_attributes['gathernd_node']['batch_dims'] = 2 + nodes_attributes['gathernd_node']['version'] = 'opset8' + graph = build_graph(nodes_attributes, edges, inputs11) + gathernd_node = Node(graph, 'gathernd_node') + GatherND.infer(gathernd_node) + + # prepare reference results + ref_output_shape = shape_array([40, 40, 3, 5, 9]) + + # get the result + res_output_shape = graph.node['output']['shape'] + + self.assertTrue(strict_compare_tensors(ref_output_shape, res_output_shape), + 'values do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape)) + + def test_infer7_opset8(self): + nodes_attributes['gathernd_node']['batch_dims'] = 2 + nodes_attributes['gathernd_node']['version'] = 'opset8' + graph = build_graph(nodes_attributes, edges, inputs7) + gathernd_node = Node(graph, 'gathernd_node') + GatherND.infer(gathernd_node) + + # get the result + res_output_value = graph.node['output']['value'] + + output = output7.reshape([2, 3, 1]) + + self.assertTrue(np.array_equal(output, res_output_value), + 'values do not match expected: {} and given: {}'.format(output, res_output_value)) + + def test_infer8_opset8(self): + nodes_attributes['gathernd_node']['batch_dims'] = 2 + nodes_attributes['gathernd_node']['version'] = 'opset8' + graph = build_graph(nodes_attributes, edges, inputs8) + gathernd_node = Node(graph, 'gathernd_node') + GatherND.infer(gathernd_node) + + # get the result + res_output_value = graph.node['output']['value'] + + output = output8.reshape([2, 3, 3]) + + self.assertTrue(np.array_equal(output, res_output_value), + 'values do not match expected: {} and given: {}'.format(output, res_output_value)) From e8039655004e12e6f6a12b8ac6b88d24ab718681 Mon Sep 17 00:00:00 2001 From: achetver Date: Fri, 22 Oct 2021 10:57:59 +0300 Subject: [PATCH 20/29] Add visitor tests for gatherND_8 operation --- ngraph/test/visitors/op/gather_nd.cpp | 41 +++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 ngraph/test/visitors/op/gather_nd.cpp diff --git a/ngraph/test/visitors/op/gather_nd.cpp b/ngraph/test/visitors/op/gather_nd.cpp new file mode 100644 index 00000000000000..c4d44ef34fc945 --- /dev/null +++ b/ngraph/test/visitors/op/gather_nd.cpp @@ -0,0 +1,41 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "gtest/gtest.h" +#include "ngraph/ngraph.hpp" +#include "ngraph/opsets/opset1.hpp" +#include "ngraph/opsets/opset5.hpp" +#include "ngraph/opsets/opset8.hpp" +#include "util/visitor.hpp" + +using namespace std; +using namespace ngraph; +using ngraph::test::NodeBuilder; +using ngraph::test::ValueMap; + +TEST(attributes, gather_nd_v5_op) { + NodeBuilder::get_ops().register_factory(); + auto data = make_shared(element::i32, Shape{2, 3, 4}); + auto indices = make_shared(element::i32, Shape{2}); + int64_t batch_dims = 1; + + auto gather = make_shared(data, indices, batch_dims); + NodeBuilder builder(gather); + auto g_gather = ov::as_type_ptr(builder.create()); + + EXPECT_EQ(g_gather->get_batch_dims(), gather->get_batch_dims()); +} + +TEST(attributes, gather_v8_op) { + NodeBuilder::get_ops().register_factory(); + auto data = make_shared(element::i32, Shape{2, 3, 4}); + auto indices = make_shared(element::i32, Shape{2}); + int64_t batch_dims = 1; + + auto gather = make_shared(data, indices, batch_dims); + NodeBuilder builder(gather); + auto g_gather = ov::as_type_ptr(builder.create()); + + EXPECT_EQ(g_gather->get_batch_dims(), gather->get_batch_dims()); +} From f0f834c4de16e96af97a552689fa782c8a78bdfd Mon Sep 17 00:00:00 2001 From: achetver Date: Fri, 22 Oct 2021 11:41:22 +0300 Subject: [PATCH 21/29] Correct comment --- ngraph/core/include/openvino/op/util/gather_nd_base.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ngraph/core/include/openvino/op/util/gather_nd_base.hpp b/ngraph/core/include/openvino/op/util/gather_nd_base.hpp index 44b923954b64d4..cd23561c3e4da6 100644 --- a/ngraph/core/include/openvino/op/util/gather_nd_base.hpp +++ b/ngraph/core/include/openvino/op/util/gather_nd_base.hpp @@ -21,7 +21,7 @@ class OPENVINO_API GatherNDBase : public Op { /// \param data Node producing data that are gathered /// \param indices Node producing indices by which the operation gathers elements /// or slices from data - /// \param batch_dims Specifies a number of batch dimensions + /// \param batch_dims Specifies a leading number of dimensions representing the batches GatherNDBase(const Output& data, const Output& indices, const size_t batch_dims = 0); size_t get_batch_dims() const { From efe6c5b32bb4f5a1dbc0bd92aa9aba31252cec60 Mon Sep 17 00:00:00 2001 From: achetver Date: Mon, 25 Oct 2021 11:40:40 +0300 Subject: [PATCH 22/29] Add additional check is shape inference function --- model-optimizer/extensions/ops/gathernd.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/model-optimizer/extensions/ops/gathernd.py b/model-optimizer/extensions/ops/gathernd.py index fe0f903d1df106..d3eae18f1e3117 100644 --- a/model-optimizer/extensions/ops/gathernd.py +++ b/model-optimizer/extensions/ops/gathernd.py @@ -60,6 +60,7 @@ def infer(node: Node): 'name : {}'.format(node['version'], node.soft_get('name')) # compute output shape + batch = [] if batch_dims > 0: if node['version'] == 'opset5': # Support old version of gatherND shape inference if is_fully_defined(data_shape[:batch_dims]): @@ -67,6 +68,9 @@ def infer(node: Node): else: batch = [dynamic_dimension_value] elif node['version'] == 'opset8': + for dim in range(batch_dims): + assert compatible_dims(indices_shape[dim], data_shape[dim]),\ + "Batch dimensions in data.shape and indices.shape must be compatible" if is_fully_defined(indices_shape[:batch_dims]): batch = indices_shape[:batch_dims].tolist() elif is_fully_defined(data_shape[:batch_dims]): @@ -81,9 +85,6 @@ def infer(node: Node): else: batch.append(dynamic_dimension_value) - else: # if batch_dims == 0 - batch = [] - slice_shape = list(data_shape[(batch_dims + indices_shape[-1]):]) batch_dims_size = 1 From 3cd22edf883fa58ecb0eb0f3c3f9958faa8a1118 Mon Sep 17 00:00:00 2001 From: achetver Date: Mon, 25 Oct 2021 17:19:10 +0300 Subject: [PATCH 23/29] Update shape inference implementation for gathernd operartion --- ngraph/core/include/openvino/op/gather_nd.hpp | 2 +- .../openvino/op/util/gather_nd_base.hpp | 2 + ngraph/core/src/op/gather_nd.cpp | 185 +++--------------- ngraph/core/src/op/util/gather_nd_base.cpp | 91 +++++++++ 4 files changed, 116 insertions(+), 164 deletions(-) diff --git a/ngraph/core/include/openvino/op/gather_nd.hpp b/ngraph/core/include/openvino/op/gather_nd.hpp index 19f0207d616348..3a46f12243bc04 100644 --- a/ngraph/core/include/openvino/op/gather_nd.hpp +++ b/ngraph/core/include/openvino/op/gather_nd.hpp @@ -35,7 +35,7 @@ namespace v8 { /// class OPENVINO_API GatherND : public op::util::GatherNDBase { public: - OPENVINO_OP("GatherND", "opset8", op::util::GatherNDBase, 8); + OPENVINO_OP("GatherND", "opset8", op::util::GatherNDBase); BWDCMP_RTTI_DECLARATION; GatherND() = default; diff --git a/ngraph/core/include/openvino/op/util/gather_nd_base.hpp b/ngraph/core/include/openvino/op/util/gather_nd_base.hpp index cd23561c3e4da6..79eea70da66971 100644 --- a/ngraph/core/include/openvino/op/util/gather_nd_base.hpp +++ b/ngraph/core/include/openvino/op/util/gather_nd_base.hpp @@ -28,6 +28,8 @@ class OPENVINO_API GatherNDBase : public Op { return m_batch_dims; } + void validate_inputs_and_infer_shape(); + protected: size_t m_batch_dims = 0; }; diff --git a/ngraph/core/src/op/gather_nd.cpp b/ngraph/core/src/op/gather_nd.cpp index 90dd5be24a20b6..4b7407f249bb9f 100644 --- a/ngraph/core/src/op/gather_nd.cpp +++ b/ngraph/core/src/op/gather_nd.cpp @@ -21,92 +21,37 @@ op::v5::GatherND::GatherND(const Output& data, const Output& indices void op::v5::GatherND::validate_and_infer_types() { NGRAPH_OP_SCOPE(v5_GatherND_validate_and_infer_types); - // check types of input tensors - const auto& data_type = get_input_element_type(0); - const auto& indices_type = get_input_element_type(1); + validate_inputs_and_infer_shape(); - NODE_VALIDATION_CHECK(this, - indices_type.is_integral_number(), - "The indices type is expected to be an integer type. Got: ", - indices_type); + // If we have m_batch_dims > 1 we need to fuse batch dimensions of output + if (m_batch_dims > 1) { + const auto& output_pshape = get_output_partial_shape(0); + const auto& data_type = get_input_element_type(0); - // check ranks of input tensors - const auto& data_pshape = get_input_partial_shape(0); - const auto& indices_pshape = get_input_partial_shape(1); - - if (data_pshape.rank().is_static()) { - NODE_VALIDATION_CHECK(this, data_pshape.rank().get_length() > 0, "Data rank must be at least 1."); - - NODE_VALIDATION_CHECK(this, - data_pshape.rank().get_length() > static_cast(m_batch_dims), - "Number of batch dimensions must not exceed a rank of data."); - } - - if (indices_pshape.rank().is_static()) { - NODE_VALIDATION_CHECK(this, indices_pshape.rank().get_length() > 0, "Indices rank must be at least 1."); - - NODE_VALIDATION_CHECK(this, - indices_pshape.rank().get_length() > static_cast(m_batch_dims), - "Number of batch dimensions must not exceed a rank of indices."); - } - - if (data_pshape.rank().is_static() && indices_pshape.rank().is_static()) { - // check that batch dimensions of data and indices are the same - for (size_t batch_dim = 0; batch_dim < m_batch_dims; batch_dim++) { - if (data_pshape[batch_dim].is_static() && indices_pshape[batch_dim].is_static()) { - NODE_VALIDATION_CHECK(this, - data_pshape[batch_dim].get_length() == indices_pshape[batch_dim].get_length(), - "Batch dimensions of data and indices must be the same."); - } - } - - if (indices_pshape[indices_pshape.rank().get_length() - 1].is_static()) { - NODE_VALIDATION_CHECK( - this, - static_cast(indices_pshape[indices_pshape.rank().get_length() - 1].get_length() + - m_batch_dims) <= data_pshape.rank().get_length(), - "Length of a tuple with indices must not exceed a rank of data tensor " - "excluding " - "batch dimensions."); - } - } - - // set output shape - set_output_size(1); - if (data_pshape.rank().is_static() && indices_pshape.rank().is_static() && - indices_pshape[indices_pshape.rank().get_length() - 1].is_static()) { - auto indices_tuple_length = indices_pshape[indices_pshape.rank().get_length() - 1].get_length(); - int64_t slice_length = data_pshape.rank().get_length() - indices_tuple_length - m_batch_dims; - int64_t output_indices_length = indices_pshape.rank().get_length() - m_batch_dims - 1; - auto output_rank = output_indices_length + slice_length; - size_t delta_output_rank = 0; - if (m_batch_dims > 0) { - delta_output_rank = 1; - } - std::vector output_shape(output_rank + delta_output_rank); - if (m_batch_dims > 0) { + if (output_pshape.rank().is_static()) { + const auto& out_size = output_pshape.size(); + std::vector output_shape(out_size - m_batch_dims + 1); output_shape[0] = 1; for (size_t dim = 0; dim < m_batch_dims; dim++) { - if (data_pshape[dim].is_static()) { - output_shape[0] *= data_pshape[dim].get_length(); - } else if (indices_pshape[dim].is_static()) { - output_shape[0] *= indices_pshape[dim].get_length(); + if (output_pshape[dim].is_static()) { + output_shape[0] *= output_pshape[dim].get_length(); } else { output_shape[0] = Dimension::dynamic(); break; } } + size_t ind = 1; + for (size_t dim = m_batch_dims; dim < out_size; dim++) { + if (output_pshape[dim].is_static()) { + output_shape[ind] = output_pshape[dim].get_length(); + } else { + output_shape[ind] = Dimension::dynamic(); + } + ind++; + } + + set_output_type(0, data_type, ov::PartialShape(output_shape)); } - for (int64_t dim = 0; dim < output_indices_length; dim++) { - output_shape[dim + delta_output_rank] = indices_pshape[dim + m_batch_dims]; - } - for (int64_t dim = 0; dim < slice_length; dim++) { - output_shape[output_indices_length + dim + delta_output_rank] = - data_pshape[m_batch_dims + indices_tuple_length + dim]; - } - set_output_type(0, data_type, ov::PartialShape(output_shape)); - } else { - set_output_type(0, data_type, ov::PartialShape::dynamic()); } } @@ -132,93 +77,7 @@ op::v8::GatherND::GatherND(const Output& data, const Output& indices void op::v8::GatherND::validate_and_infer_types() { NGRAPH_OP_SCOPE(v8_GatherND_validate_and_infer_types); - // check types of input tensors - const auto& data_type = get_input_element_type(0); - const auto& indices_type = get_input_element_type(1); - - NODE_VALIDATION_CHECK(this, - indices_type.is_integral_number(), - "The indices type is expected to be an integer type. Got: ", - indices_type); - - // check ranks of input tensors - const auto& data_pshape = get_input_partial_shape(0); - const auto& indices_pshape = get_input_partial_shape(1); - - if (data_pshape.rank().is_static()) { - NODE_VALIDATION_CHECK(this, data_pshape.rank().get_length() > 0, "Data rank must be at least 1."); - - NODE_VALIDATION_CHECK(this, - data_pshape.rank().get_length() > static_cast(m_batch_dims), - "Number of batch dimensions must not exceed a rank of data."); - } - - if (indices_pshape.rank().is_static()) { - NODE_VALIDATION_CHECK(this, indices_pshape.rank().get_length() > 0, "Indices rank must be at least 1."); - - NODE_VALIDATION_CHECK(this, - indices_pshape.rank().get_length() > static_cast(m_batch_dims), - "Number of batch dimensions must not exceed a rank of indices."); - } - - if (data_pshape.rank().is_static() && indices_pshape.rank().is_static()) { - // check that batch dimensions of data and indices are the same - for (size_t batch_dim = 0; batch_dim < m_batch_dims; batch_dim++) { - if (data_pshape[batch_dim].is_static() && indices_pshape[batch_dim].is_static()) { - NODE_VALIDATION_CHECK(this, - data_pshape[batch_dim].get_length() == indices_pshape[batch_dim].get_length(), - "Batch dimensions of data and indices must be the same."); - } - } - - if (indices_pshape[indices_pshape.rank().get_length() - 1].is_static()) { - NODE_VALIDATION_CHECK( - this, - static_cast(indices_pshape[indices_pshape.rank().get_length() - 1].get_length() + - m_batch_dims) <= data_pshape.rank().get_length(), - "Length of a tuple with indices must not exceed a rank of data tensor " - "excluding " - "batch dimensions."); - } - } - - // set output shape - set_output_size(1); - if (data_pshape.rank().is_static() && indices_pshape.rank().is_static() && - indices_pshape[indices_pshape.rank().get_length() - 1].is_static()) { - auto indices_tuple_length = indices_pshape[indices_pshape.rank().get_length() - 1].get_length(); - int64_t slice_length = data_pshape.rank().get_length() - indices_tuple_length - m_batch_dims; - int64_t output_indices_length = indices_pshape.rank().get_length() - m_batch_dims - 1; - auto output_rank = output_indices_length + slice_length; - size_t delta_output_rank = 0; - if (m_batch_dims > 0) { - delta_output_rank = m_batch_dims; - } - std::vector output_shape(output_rank + delta_output_rank); - if (m_batch_dims > 0) { - for (size_t dim = 0; dim < m_batch_dims; dim++) { - output_shape[dim] = 1; - if (data_pshape[dim].is_static()) { - output_shape[dim] = data_pshape[dim].get_length(); - } else if (indices_pshape[dim].is_static()) { - output_shape[dim] = indices_pshape[dim].get_length(); - } else { - output_shape[dim] = Dimension::dynamic(); - break; - } - } - } - for (int64_t dim = 0; dim < output_indices_length; dim++) { - output_shape[dim + delta_output_rank] = indices_pshape[dim + m_batch_dims]; - } - for (int64_t dim = 0; dim < slice_length; dim++) { - output_shape[output_indices_length + dim + delta_output_rank] = - data_pshape[m_batch_dims + indices_tuple_length + dim]; - } - set_output_type(0, data_type, ov::PartialShape(output_shape)); - } else { - set_output_type(0, data_type, ov::PartialShape::dynamic()); - } + validate_inputs_and_infer_shape(); } bool op::v8::GatherND::visit_attributes(AttributeVisitor& visitor) { diff --git a/ngraph/core/src/op/util/gather_nd_base.cpp b/ngraph/core/src/op/util/gather_nd_base.cpp index f00d0510d1acbc..86a4749264065d 100644 --- a/ngraph/core/src/op/util/gather_nd_base.cpp +++ b/ngraph/core/src/op/util/gather_nd_base.cpp @@ -12,6 +12,7 @@ #include "ngraph/runtime/host_tensor.hpp" #include "ngraph/shape.hpp" + using namespace std; BWDCMP_RTTI_DEFINITION(ov::op::util::GatherNDBase); @@ -21,3 +22,93 @@ ov::op::util::GatherNDBase::GatherNDBase(const Output& data, const Output< m_batch_dims(batch_dims) { constructor_validate_and_infer_types(); } + +void ov::op::util::GatherNDBase::validate_inputs_and_infer_shape() { + // check types of input tensors + const auto& data_type = get_input_element_type(0); + const auto& indices_type = get_input_element_type(1); + + NODE_VALIDATION_CHECK(this, + indices_type.is_integral_number(), + "The indices type is expected to be an integer type. Got: ", + indices_type); + + // check ranks of input tensors + const auto& data_pshape = get_input_partial_shape(0); + const auto& indices_pshape = get_input_partial_shape(1); + + if (data_pshape.rank().is_static()) { + NODE_VALIDATION_CHECK(this, data_pshape.rank().get_length() > 0, "Data rank must be at least 1."); + + NODE_VALIDATION_CHECK(this, + data_pshape.rank().get_length() > static_cast(m_batch_dims), + "Number of batch dimensions must not exceed a rank of data."); + } + + if (indices_pshape.rank().is_static()) { + NODE_VALIDATION_CHECK(this, indices_pshape.rank().get_length() > 0, "Indices rank must be at least 1."); + + NODE_VALIDATION_CHECK(this, + indices_pshape.rank().get_length() > static_cast(m_batch_dims), + "Number of batch dimensions must not exceed a rank of indices."); + } + + if (data_pshape.rank().is_static() && indices_pshape.rank().is_static()) { + // check that batch dimensions of data and indices are the same + for (size_t batch_dim = 0; batch_dim < m_batch_dims; batch_dim++) { + if (data_pshape[batch_dim].is_static() && indices_pshape[batch_dim].is_static()) { + NODE_VALIDATION_CHECK(this, + data_pshape[batch_dim].get_length() == indices_pshape[batch_dim].get_length(), + "Batch dimensions of data and indices must be the same."); + } + } + + if (indices_pshape[indices_pshape.rank().get_length() - 1].is_static()) { + NODE_VALIDATION_CHECK( + this, + static_cast(indices_pshape[indices_pshape.rank().get_length() - 1].get_length() + + m_batch_dims) <= data_pshape.rank().get_length(), + "Length of a tuple with indices must not exceed a rank of data tensor " + "excluding " + "batch dimensions."); + } + } + + // set output shape + set_output_size(1); + if (data_pshape.rank().is_static() && indices_pshape.rank().is_static() && + indices_pshape[indices_pshape.rank().get_length() - 1].is_static()) { + auto indices_tuple_length = indices_pshape[indices_pshape.rank().get_length() - 1].get_length(); + int64_t slice_length = data_pshape.rank().get_length() - indices_tuple_length - m_batch_dims; + int64_t output_indices_length = indices_pshape.rank().get_length() - m_batch_dims - 1; + auto output_rank = output_indices_length + slice_length; + size_t delta_output_rank = 0; + if (m_batch_dims > 0) { + delta_output_rank = m_batch_dims; + } + std::vector output_shape(output_rank + delta_output_rank); + if (m_batch_dims > 0) { + for (size_t dim = 0; dim < m_batch_dims; dim++) { + output_shape[dim] = 1; + if (data_pshape[dim].is_static()) { + output_shape[dim] = data_pshape[dim].get_length(); + } else if (indices_pshape[dim].is_static()) { + output_shape[dim] = indices_pshape[dim].get_length(); + } else { + output_shape[dim] = Dimension::dynamic(); + break; + } + } + } + for (int64_t dim = 0; dim < output_indices_length; dim++) { + output_shape[dim + delta_output_rank] = indices_pshape[dim + m_batch_dims]; + } + for (int64_t dim = 0; dim < slice_length; dim++) { + output_shape[output_indices_length + dim + delta_output_rank] = + data_pshape[m_batch_dims + indices_tuple_length + dim]; + } + set_output_type(0, data_type, ov::PartialShape(output_shape)); + } else { + set_output_type(0, data_type, ov::PartialShape::dynamic()); + } +} From 1e194487c88eadef561812f1e4c7b96826bb54f0 Mon Sep 17 00:00:00 2001 From: achetver Date: Tue, 26 Oct 2021 01:22:00 +0300 Subject: [PATCH 24/29] Fix codestyle --- ngraph/core/src/op/util/gather_nd_base.cpp | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/ngraph/core/src/op/util/gather_nd_base.cpp b/ngraph/core/src/op/util/gather_nd_base.cpp index 86a4749264065d..926bb6f272e478 100644 --- a/ngraph/core/src/op/util/gather_nd_base.cpp +++ b/ngraph/core/src/op/util/gather_nd_base.cpp @@ -12,7 +12,6 @@ #include "ngraph/runtime/host_tensor.hpp" #include "ngraph/shape.hpp" - using namespace std; BWDCMP_RTTI_DEFINITION(ov::op::util::GatherNDBase); @@ -65,12 +64,12 @@ void ov::op::util::GatherNDBase::validate_inputs_and_infer_shape() { if (indices_pshape[indices_pshape.rank().get_length() - 1].is_static()) { NODE_VALIDATION_CHECK( - this, - static_cast(indices_pshape[indices_pshape.rank().get_length() - 1].get_length() + - m_batch_dims) <= data_pshape.rank().get_length(), - "Length of a tuple with indices must not exceed a rank of data tensor " - "excluding " - "batch dimensions."); + this, + static_cast(indices_pshape[indices_pshape.rank().get_length() - 1].get_length() + + m_batch_dims) <= data_pshape.rank().get_length(), + "Length of a tuple with indices must not exceed a rank of data tensor " + "excluding " + "batch dimensions."); } } @@ -105,7 +104,7 @@ void ov::op::util::GatherNDBase::validate_inputs_and_infer_shape() { } for (int64_t dim = 0; dim < slice_length; dim++) { output_shape[output_indices_length + dim + delta_output_rank] = - data_pshape[m_batch_dims + indices_tuple_length + dim]; + data_pshape[m_batch_dims + indices_tuple_length + dim]; } set_output_type(0, data_type, ov::PartialShape(output_shape)); } else { From 96067717de3765f68cc525496250000104fe5d41 Mon Sep 17 00:00:00 2001 From: Chetverikov Date: Tue, 26 Oct 2021 15:39:06 +0300 Subject: [PATCH 25/29] Remove restriction for data is fully defined --- model-optimizer/extensions/ops/gathernd.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/model-optimizer/extensions/ops/gathernd.py b/model-optimizer/extensions/ops/gathernd.py index d3eae18f1e3117..e7b08193d31c9d 100644 --- a/model-optimizer/extensions/ops/gathernd.py +++ b/model-optimizer/extensions/ops/gathernd.py @@ -95,8 +95,8 @@ def infer(node: Node): output_shape = batch + list(indices_shape)[batch_dims:-1] + slice_shape node.out_port(0).data.set_shape(output_shape) - # compute output value if all input values are defined - if is_fully_defined(indices_value) and is_fully_defined(data_value): + # compute output value if all input indices are defined + if is_fully_defined(indices_value): output_data = [] reshaped_indices = indices_value.reshape(batch_dims_size, -1, indices_shape[-1]) From b6d612e15b65e823fe331f0fbffb32f69baeac11 Mon Sep 17 00:00:00 2001 From: Chetverikov Date: Wed, 27 Oct 2021 09:56:34 +0300 Subject: [PATCH 26/29] Update shape inference functon --- model-optimizer/extensions/ops/gathernd.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/model-optimizer/extensions/ops/gathernd.py b/model-optimizer/extensions/ops/gathernd.py index e7b08193d31c9d..4255335e2c0f95 100644 --- a/model-optimizer/extensions/ops/gathernd.py +++ b/model-optimizer/extensions/ops/gathernd.py @@ -76,7 +76,6 @@ def infer(node: Node): elif is_fully_defined(data_shape[:batch_dims]): batch = data_shape[:batch_dims].tolist() else: - batch = [] for ind in range(batch_dims): if indices_shape[ind] != dynamic_dimension_value: batch.append(indices_shape[ind]) @@ -87,16 +86,16 @@ def infer(node: Node): slice_shape = list(data_shape[(batch_dims + indices_shape[-1]):]) - batch_dims_size = 1 - - for i in range(batch_dims): - batch_dims_size *= indices_shape[i] - output_shape = batch + list(indices_shape)[batch_dims:-1] + slice_shape node.out_port(0).data.set_shape(output_shape) # compute output value if all input indices are defined if is_fully_defined(indices_value): + batch_dims_size = 1 + + for i in range(batch_dims): + batch_dims_size *= indices_shape[i] + output_data = [] reshaped_indices = indices_value.reshape(batch_dims_size, -1, indices_shape[-1]) From e7495b4c7285829a598167fcb9938ae711a91e93 Mon Sep 17 00:00:00 2001 From: achetver Date: Thu, 28 Oct 2021 11:23:13 +0300 Subject: [PATCH 27/29] Fix missed check for nonetype --- model-optimizer/extensions/ops/gathernd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model-optimizer/extensions/ops/gathernd.py b/model-optimizer/extensions/ops/gathernd.py index 4255335e2c0f95..01e6e2f36c4794 100644 --- a/model-optimizer/extensions/ops/gathernd.py +++ b/model-optimizer/extensions/ops/gathernd.py @@ -90,7 +90,7 @@ def infer(node: Node): node.out_port(0).data.set_shape(output_shape) # compute output value if all input indices are defined - if is_fully_defined(indices_value): + if is_fully_defined(indices_value) and data_value is not None: batch_dims_size = 1 for i in range(batch_dims): From 5b51350398e91d7fcab94493be3ecd6fed2aea85 Mon Sep 17 00:00:00 2001 From: achetver Date: Thu, 28 Oct 2021 11:33:34 +0300 Subject: [PATCH 28/29] Remove redundant checks for batch_dims --- ngraph/core/src/op/util/gather_nd_base.cpp | 24 +++++++++------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/ngraph/core/src/op/util/gather_nd_base.cpp b/ngraph/core/src/op/util/gather_nd_base.cpp index 926bb6f272e478..f0bed099b3662e 100644 --- a/ngraph/core/src/op/util/gather_nd_base.cpp +++ b/ngraph/core/src/op/util/gather_nd_base.cpp @@ -82,21 +82,17 @@ void ov::op::util::GatherNDBase::validate_inputs_and_infer_shape() { int64_t output_indices_length = indices_pshape.rank().get_length() - m_batch_dims - 1; auto output_rank = output_indices_length + slice_length; size_t delta_output_rank = 0; - if (m_batch_dims > 0) { - delta_output_rank = m_batch_dims; - } + delta_output_rank = m_batch_dims; std::vector output_shape(output_rank + delta_output_rank); - if (m_batch_dims > 0) { - for (size_t dim = 0; dim < m_batch_dims; dim++) { - output_shape[dim] = 1; - if (data_pshape[dim].is_static()) { - output_shape[dim] = data_pshape[dim].get_length(); - } else if (indices_pshape[dim].is_static()) { - output_shape[dim] = indices_pshape[dim].get_length(); - } else { - output_shape[dim] = Dimension::dynamic(); - break; - } + for (size_t dim = 0; dim < m_batch_dims; dim++) { + output_shape[dim] = 1; + if (data_pshape[dim].is_static()) { + output_shape[dim] = data_pshape[dim].get_length(); + } else if (indices_pshape[dim].is_static()) { + output_shape[dim] = indices_pshape[dim].get_length(); + } else { + output_shape[dim] = Dimension::dynamic(); + break; } } for (int64_t dim = 0; dim < output_indices_length; dim++) { From 795e31b6ed6c4103bcef7e5d0357b6edb5fb9acb Mon Sep 17 00:00:00 2001 From: Chetverikov Date: Wed, 3 Nov 2021 14:39:47 +0300 Subject: [PATCH 29/29] Fix codestyle --- ngraph/test/visitors/op/gather_nd.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/ngraph/test/visitors/op/gather_nd.cpp b/ngraph/test/visitors/op/gather_nd.cpp index 64bfee5da73e18..5bfba312e0c78e 100644 --- a/ngraph/test/visitors/op/gather_nd.cpp +++ b/ngraph/test/visitors/op/gather_nd.cpp @@ -27,7 +27,6 @@ TEST(attributes, gather_nd_v5_op) { EXPECT_EQ(g_G->get_batch_dims(), G->get_batch_dims()); } - TEST(attributes, gather_nd_v8_op) { NodeBuilder::get_ops().register_factory(); int batch_dims = 1;