Skip to content

Commit

Permalink
[ref]: Added SearchSorted ref impl (#26958)
Browse files Browse the repository at this point in the history
### Details:
 - Added ref implementation of search sorted op

### Tickets:
 - *CVS-154061*

Depends on: #26904

---------

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: Michal Lukaszewski <michal.lukaszewski@intel.com>
Co-authored-by: Pawel Raasz <pawel.raasz@intel.com>
Co-authored-by: Andrey Babushkin <andrey.babushkin@intel.com>
Co-authored-by: Alicja Miloszewska <alicja.miloszewska@intel.com>
Co-authored-by: Bogdan Pereanu <bogdan.pereanu@intel.com>
Co-authored-by: Karol Blaszczak <karol.blaszczak@intel.com>
Co-authored-by: Tatiana Savina <tatiana.savina@intel.com>
Co-authored-by: Anastasiya(Asya) Pronina <anastasiya.pronina@intel.com>
Co-authored-by: Dmitry Matveev <dmitry.matveev@intel.com>
Co-authored-by: Andrei Beleiu <andrei-marin.beleiu@intel.com>
Co-authored-by: Andrew Kwangwoong Park <andrew.park@intel.com>
Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com>
Co-authored-by: Pavel Durandin <pavel.durandin@intel.com>
Co-authored-by: Alexey Smirnov <alexey.smirnov@intel.com>
Co-authored-by: Hubert Błaszczyk <56601011+hub-bla@users.noreply.github.com>
Co-authored-by: Vladimir Paramuzov <vladimir.paramuzov@intel.com>
Co-authored-by: Sergey Shlyapnikov <sergey.shlyapnikov@intel.com>
Co-authored-by: Ivan Tikhonov <ivan.tikhonov@intel.com>
Co-authored-by: Andrzej Kopytko <andrzejx.kopytko@intel.com>
Co-authored-by: Sebastian Golebiewski <sebastianx.golebiewski@intel.com>
Co-authored-by: Alina Kladieva <alina.kladieva@intel.com>
Co-authored-by: Ilya Lavrenov <ilya.lavrenov@intel.com>
Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
  • Loading branch information
1 parent 79e229b commit b212815
Show file tree
Hide file tree
Showing 9 changed files with 330 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/core/include/openvino/opsets/opset15_tbl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ _OPENVINO_OP_REG(StringTensorPack, ov::op::v15)
_OPENVINO_OP_REG(BitwiseLeftShift, ov::op::v15)
_OPENVINO_OP_REG(BitwiseRightShift, ov::op::v15)
_OPENVINO_OP_REG(SliceScatter, ov::op::v15)
_OPENVINO_OP_REG(SearchSorted, ov::op::v15)
55 changes: 55 additions & 0 deletions src/core/reference/include/openvino/reference/search_sorted.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/core/shape.hpp"
#include "openvino/reference/utils/coordinate_index.hpp"
#include "openvino/reference/utils/coordinate_transform.hpp"

namespace ov {
namespace reference {
template <typename T, typename TOut = int64_t>
void search_sorted(const T* sorted,
const T* values,
TOut* out,
const Shape& sorted_shape,
const Shape& values_shape,
bool right_mode) {
const CoordinateTransformBasic values_transform{values_shape};

std::function<const T*(const T*, const T*, T)> compare_func = nullptr;
if (right_mode) {
compare_func = [](const T* begin, const T* end, T value) {
return std::lower_bound(begin, end, value, std::less_equal<T>());
};
} else {
compare_func = [](const T* begin, const T* end, T value) {
return std::lower_bound(begin, end, value, std::less<T>());
};
}

for (const Coordinate& values_coord : values_transform) {
const auto values_index = coordinate_index(values_coord, values_shape);
const T value = values[values_index];

Coordinate sorted_coord_begin = values_coord;
sorted_coord_begin.back() = 0;

Coordinate sorted_coord_last = values_coord;
sorted_coord_last.back() = sorted_shape.back();

const auto sorted_index_begin = coordinate_index(sorted_coord_begin, sorted_shape);
const auto sorted_index_last = coordinate_index(sorted_coord_last, sorted_shape);

const T* idx_ptr = compare_func(sorted + sorted_index_begin, sorted + sorted_index_last, value);

const ptrdiff_t sorted_index = (idx_ptr - sorted) - sorted_index_begin;

out[values_index] = static_cast<TOut>(sorted_index);
}
}

} // namespace reference
} // namespace ov
2 changes: 1 addition & 1 deletion src/core/tests/opset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ INSTANTIATE_TEST_SUITE_P(opset,
OpsetTestParams{ov::get_opset12, 178},
OpsetTestParams{ov::get_opset13, 186},
OpsetTestParams{ov::get_opset14, 188},
OpsetTestParams{ov::get_opset15, 14}),
OpsetTestParams{ov::get_opset15, 15}),
OpsetTestNameGenerator{});

class MyOpOld : public ov::op::Op {
Expand Down
4 changes: 4 additions & 0 deletions src/plugins/template/backend/ops/ops_evaluates.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -552,3 +552,7 @@ extern template bool evaluate_node<ov::op::v15::StringTensorUnpack>(std::shared_
extern template bool evaluate_node<ov::op::v15::StringTensorPack>(std::shared_ptr<ov::Node> node,
ov::TensorVector& outputs,
const ov::TensorVector& inputs);

extern template bool evaluate_node<ov::op::v15::SearchSorted>(std::shared_ptr<ov::Node> node,
ov::TensorVector& outputs,
const ov::TensorVector& inputs);
50 changes: 50 additions & 0 deletions src/plugins/template/backend/ops/search_sorted.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/reference/search_sorted.hpp"

#include "evaluate_node.hpp"

template <ov::element::Type_t ET>
bool evaluate(const std::shared_ptr<ov::op::v15::SearchSorted>& op,
ov::TensorVector& outputs,
const ov::TensorVector& inputs) {
using T = typename ov::element_type_traits<ET>::value_type;
ov::reference::search_sorted<T>(inputs[0].data<const T>(),
inputs[1].data<const T>(),
outputs[0].data<int64_t>(),
op->get_input_shape(0),
op->get_input_shape(1),
op->get_right_mode());
return true;
}

template <>
bool evaluate_node<ov::op::v15::SearchSorted>(std::shared_ptr<ov::Node> node,
ov::TensorVector& outputs,
const ov::TensorVector& inputs) {
const auto& element_type = node->get_input_element_type(0);

#define CASE(type) \
case ov::element::type: \
return evaluate<ov::element::type>(ov::as_type_ptr<ov::op::v15::SearchSorted>(node), outputs, inputs);

switch (element_type) {
CASE(bf16);
CASE(f16);
CASE(f32);
CASE(f64);
CASE(i8);
CASE(i16);
CASE(i32);
CASE(i64);
CASE(u8);
CASE(u16);
CASE(u32);
CASE(u64);
default:
OPENVINO_THROW("Unhandled data type ", element_type, " in evaluate_node()");
}
#undef CASE
}
1 change: 1 addition & 0 deletions src/plugins/template/backend/opset_int_tbl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ _OPENVINO_OP_REG(StringTensorPack, ov::op::v15)
_OPENVINO_OP_REG(BitwiseLeftShift, ov::op::v15)
_OPENVINO_OP_REG(BitwiseRightShift, ov::op::v15)
_OPENVINO_OP_REG(SliceScatter, ov::op::v15)
_OPENVINO_OP_REG(SearchSorted, ov::op::v15)

_OPENVINO_OP_REG(AUGRUCell, ov::op::internal)
_OPENVINO_OP_REG(AUGRUSequence, ov::op::internal)
Expand Down
123 changes: 123 additions & 0 deletions src/plugins/template/tests/functional/op_reference/search_sorted.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/op/search_sorted.hpp"

#include <gtest/gtest.h>

#include "base_reference_test.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/parameter.hpp"

using namespace reference_tests;
using namespace ov;

namespace {

struct SearchSortedParams {
PartialShape sortedShape;
PartialShape valuesShape;
bool rightMode;
std::string testcaseName;
reference_tests::Tensor sorted;
reference_tests::Tensor values;
reference_tests::Tensor expectedOutput;
};

template <typename T>
SearchSortedParams PrepareTestCaseParams(const PartialShape& sortedShape,
const PartialShape& valuesShape,
bool rightMode,
const std::vector<T>& sortedData,
const std::vector<T>& valuesData,
const std::vector<int64_t>& expectedData,
const std::string& testcaseName) {
SearchSortedParams ret;
const auto elementType = element::from<T>();

ret.sortedShape = sortedShape;
ret.valuesShape = valuesShape;
ret.rightMode = rightMode;
ret.testcaseName = testcaseName;
ret.sorted = reference_tests::Tensor(elementType, sortedShape.to_shape(), sortedData);
ret.values = reference_tests::Tensor(elementType, valuesShape.to_shape(), valuesData);
ret.expectedOutput = reference_tests::Tensor(element::Type_t::i64, valuesShape.to_shape(), expectedData);

return ret;
}

class ReferenceSearchSortedTest : public testing::TestWithParam<SearchSortedParams>, public CommonReferenceTest {
public:
void SetUp() override {
const auto& params = GetParam();
function = CreateFunction(params);
inputData = {params.sorted.data, params.values.data};
refOutData = {params.expectedOutput.data};
}

static std::string getTestCaseName(const testing::TestParamInfo<SearchSortedParams>& obj) {
auto param = obj.param;
std::ostringstream result;
result << "type=" << param.sorted.data.get_element_type();
result << "_sortedShape=" << param.sortedShape;
result << "_valuesShape=" << param.valuesShape;
result << "_rightMode=" << param.rightMode;
result << "_=" << param.testcaseName;

return result.str();
}

private:
static std::shared_ptr<Model> CreateFunction(const SearchSortedParams& params) {
const auto sorted =
std::make_shared<op::v0::Parameter>(params.sorted.data.get_element_type(), params.sortedShape);
const auto values =
std::make_shared<op::v0::Parameter>(params.values.data.get_element_type(), params.valuesShape);

const auto op = std::make_shared<op::v15::SearchSorted>(sorted, values, params.rightMode);

return std::make_shared<Model>(NodeVector{op}, ParameterVector{sorted, values});
}
};

TEST_P(ReferenceSearchSortedTest, CompareWithRefs) {
Exec();
}

template <element::Type_t ET>
std::vector<SearchSortedParams> generateParams() {
using T = typename element_type_traits<ET>::value_type;
std::vector<SearchSortedParams> params;

#define TEST_DATA(sorted_shape, values_shape, right_mode, sorted_data, values_data, expected_output_data, description) \
params.push_back(PrepareTestCaseParams<T>(sorted_shape, \
values_shape, \
right_mode, \
sorted_data, \
values_data, \
expected_output_data, \
description));

#include "unit_test_utils/tests_data/search_sorted_data.h"
#undef TEST_DATA

return params;
}

std::vector<SearchSortedParams> generateCombinedParams() {
const std::vector<std::vector<SearchSortedParams>> generatedParams{generateParams<element::Type_t::i32>(),
generateParams<element::Type_t::f32>()};
std::vector<SearchSortedParams> combinedParams;

for (const auto& params : generatedParams) {
combinedParams.insert(combinedParams.end(), params.begin(), params.end());
}
return combinedParams;
}

INSTANTIATE_TEST_SUITE_P(smoke_SearchSorted_With_Hardcoded_Refs,
ReferenceSearchSortedTest,
testing::ValuesIn(generateCombinedParams()),
ReferenceSearchSortedTest::getTestCaseName);
} // namespace
Original file line number Diff line number Diff line change
Expand Up @@ -2066,6 +2066,15 @@ std::shared_ptr<ov::Model> generateRNNCellBase(const std::shared_ptr<ov::op::Op>
}
}

std::shared_ptr<ov::Model> generate(const std::shared_ptr<ov::op::v15::SearchSorted>& node) {
ov::ParameterVector params{std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{16})};
const auto values =
std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{2, 3}, std::vector<float>(6, 0));
auto new_node = std::make_shared<ov::op::v15::SearchSorted>(params.at(0), values);
ov::ResultVector results{std::make_shared<ov::op::v0::Result>(new_node)};
return std::make_shared<ov::Model>(results, params, "SearchSortedGraph");
}

std::shared_ptr<ov::Model> generateSubGraphOp(const std::shared_ptr<ov::op::Op> &node) {
ov::ParameterVector params{std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{{2, 2}}),
std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{{2, 2}}),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
#pragma once

#define LIST(...) \
{ __VA_ARGS__ }

// TEST_DATA(sorted_shape,
// values_shape,
// right_mode,
// sorted_data,
// values_data,
// expected_output_data,
// description)

// NOTE: expected output were generated using pyTorch.searchsorted implementation.

TEST_DATA(LIST(5),
LIST(2, 3),
false,
LIST(1, 3, 5, 7, 9),
LIST(3, 6, 9, 3, 6, 9),
LIST(1, 3, 4, 1, 3, 4),
"1d_tensor_1");

TEST_DATA(LIST(5),
LIST(4, 3),
false,
LIST(1, 3, 5, 7, 9),
LIST(0, 6, 20, 1, 6, 9, 1, 0, 0, 9, 10, 20),
LIST(0, 3, 5, 0, 3, 4, 0, 0, 0, 4, 5, 5),
"1d_tensor_2");

TEST_DATA(LIST(5),
LIST(4, 3),
true,
LIST(1, 3, 5, 7, 9),
LIST(0, 6, 20, 1, 6, 9, 1, 0, 0, 9, 10, 20),
LIST(0, 3, 5, 1, 3, 5, 1, 0, 0, 5, 5, 5),
"1d_tensor_2_right_mode");

TEST_DATA(LIST(5),
LIST(2, 2, 3),
false,
LIST(1, 3, 5, 7, 9),
LIST(0, 6, 20, 1, 6, 9, 1, 0, 0, 9, 10, 20),
LIST(0, 3, 5, 0, 3, 4, 0, 0, 0, 4, 5, 5),
"1d_tensor_3");

TEST_DATA(LIST(5),
LIST(2, 2, 3),
true,
LIST(1, 3, 5, 7, 9),
LIST(0, 6, 20, 1, 6, 9, 1, 0, 0, 9, 10, 20),
LIST(0, 3, 5, 1, 3, 5, 1, 0, 0, 5, 5, 5),
"1d_tensor_3_right_mode");

TEST_DATA(LIST(2, 5),
LIST(2, 3),
false,
LIST(1, 3, 5, 7, 9, 2, 4, 6, 8, 10),
LIST(3, 6, 9, 3, 6, 9),
LIST(1, 3, 4, 1, 2, 4),
"nd_tensor_1");

TEST_DATA(LIST(2, 5),
LIST(2, 3),
true,
LIST(1, 3, 5, 7, 9, 2, 4, 6, 8, 10),
LIST(3, 6, 9, 3, 6, 9),
LIST(2, 3, 5, 1, 3, 4),
"nd_tensor_1_right_mode");

TEST_DATA(LIST(2, 2, 5),
LIST(2, 2, 3),
false,
LIST(1, 3, 5, 7, 9, 0, 2, 4, 6, 8, -20, 5, 10, 23, 41, 100, 125, 130, 132, 139),
LIST(0, 6, 20, 1, 6, 9, 1, 0, 0, 9, 10, 20),
LIST(0, 3, 5, 1, 3, 5, 1, 1, 1, 0, 0, 0),
"nd_tensor_2");

TEST_DATA(LIST(2, 2, 5),
LIST(2, 2, 3),
true,
LIST(1, 3, 5, 7, 9, 0, 2, 4, 6, 8, -20, 5, 10, 23, 41, 100, 125, 130, 132, 139),
LIST(0, 6, 20, 1, 6, 9, 1, 0, 0, 9, 10, 20),
LIST(0, 3, 5, 1, 4, 5, 1, 1, 1, 0, 0, 0),
"nd_tensor_2");

0 comments on commit b212815

Please sign in to comment.