-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ref]: Added SearchSorted ref impl (#26958)
### Details: - Added ref implementation of search sorted op ### Tickets: - *CVS-154061* Depends on: #26904 --------- Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: Michal Lukaszewski <michal.lukaszewski@intel.com> Co-authored-by: Pawel Raasz <pawel.raasz@intel.com> Co-authored-by: Andrey Babushkin <andrey.babushkin@intel.com> Co-authored-by: Alicja Miloszewska <alicja.miloszewska@intel.com> Co-authored-by: Bogdan Pereanu <bogdan.pereanu@intel.com> Co-authored-by: Karol Blaszczak <karol.blaszczak@intel.com> Co-authored-by: Tatiana Savina <tatiana.savina@intel.com> Co-authored-by: Anastasiya(Asya) Pronina <anastasiya.pronina@intel.com> Co-authored-by: Dmitry Matveev <dmitry.matveev@intel.com> Co-authored-by: Andrei Beleiu <andrei-marin.beleiu@intel.com> Co-authored-by: Andrew Kwangwoong Park <andrew.park@intel.com> Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com> Co-authored-by: Pavel Durandin <pavel.durandin@intel.com> Co-authored-by: Alexey Smirnov <alexey.smirnov@intel.com> Co-authored-by: Hubert Błaszczyk <56601011+hub-bla@users.noreply.github.com> Co-authored-by: Vladimir Paramuzov <vladimir.paramuzov@intel.com> Co-authored-by: Sergey Shlyapnikov <sergey.shlyapnikov@intel.com> Co-authored-by: Ivan Tikhonov <ivan.tikhonov@intel.com> Co-authored-by: Andrzej Kopytko <andrzejx.kopytko@intel.com> Co-authored-by: Sebastian Golebiewski <sebastianx.golebiewski@intel.com> Co-authored-by: Alina Kladieva <alina.kladieva@intel.com> Co-authored-by: Ilya Lavrenov <ilya.lavrenov@intel.com> Co-authored-by: Maxim Vafin <maxim.vafin@intel.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
- Loading branch information
1 parent
79e229b
commit b212815
Showing
9 changed files
with
330 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
55 changes: 55 additions & 0 deletions
55
src/core/reference/include/openvino/reference/search_sorted.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "openvino/core/shape.hpp" | ||
#include "openvino/reference/utils/coordinate_index.hpp" | ||
#include "openvino/reference/utils/coordinate_transform.hpp" | ||
|
||
namespace ov { | ||
namespace reference { | ||
template <typename T, typename TOut = int64_t> | ||
void search_sorted(const T* sorted, | ||
const T* values, | ||
TOut* out, | ||
const Shape& sorted_shape, | ||
const Shape& values_shape, | ||
bool right_mode) { | ||
const CoordinateTransformBasic values_transform{values_shape}; | ||
|
||
std::function<const T*(const T*, const T*, T)> compare_func = nullptr; | ||
if (right_mode) { | ||
compare_func = [](const T* begin, const T* end, T value) { | ||
return std::lower_bound(begin, end, value, std::less_equal<T>()); | ||
}; | ||
} else { | ||
compare_func = [](const T* begin, const T* end, T value) { | ||
return std::lower_bound(begin, end, value, std::less<T>()); | ||
}; | ||
} | ||
|
||
for (const Coordinate& values_coord : values_transform) { | ||
const auto values_index = coordinate_index(values_coord, values_shape); | ||
const T value = values[values_index]; | ||
|
||
Coordinate sorted_coord_begin = values_coord; | ||
sorted_coord_begin.back() = 0; | ||
|
||
Coordinate sorted_coord_last = values_coord; | ||
sorted_coord_last.back() = sorted_shape.back(); | ||
|
||
const auto sorted_index_begin = coordinate_index(sorted_coord_begin, sorted_shape); | ||
const auto sorted_index_last = coordinate_index(sorted_coord_last, sorted_shape); | ||
|
||
const T* idx_ptr = compare_func(sorted + sorted_index_begin, sorted + sorted_index_last, value); | ||
|
||
const ptrdiff_t sorted_index = (idx_ptr - sorted) - sorted_index_begin; | ||
|
||
out[values_index] = static_cast<TOut>(sorted_index); | ||
} | ||
} | ||
|
||
} // namespace reference | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "openvino/reference/search_sorted.hpp" | ||
|
||
#include "evaluate_node.hpp" | ||
|
||
template <ov::element::Type_t ET> | ||
bool evaluate(const std::shared_ptr<ov::op::v15::SearchSorted>& op, | ||
ov::TensorVector& outputs, | ||
const ov::TensorVector& inputs) { | ||
using T = typename ov::element_type_traits<ET>::value_type; | ||
ov::reference::search_sorted<T>(inputs[0].data<const T>(), | ||
inputs[1].data<const T>(), | ||
outputs[0].data<int64_t>(), | ||
op->get_input_shape(0), | ||
op->get_input_shape(1), | ||
op->get_right_mode()); | ||
return true; | ||
} | ||
|
||
template <> | ||
bool evaluate_node<ov::op::v15::SearchSorted>(std::shared_ptr<ov::Node> node, | ||
ov::TensorVector& outputs, | ||
const ov::TensorVector& inputs) { | ||
const auto& element_type = node->get_input_element_type(0); | ||
|
||
#define CASE(type) \ | ||
case ov::element::type: \ | ||
return evaluate<ov::element::type>(ov::as_type_ptr<ov::op::v15::SearchSorted>(node), outputs, inputs); | ||
|
||
switch (element_type) { | ||
CASE(bf16); | ||
CASE(f16); | ||
CASE(f32); | ||
CASE(f64); | ||
CASE(i8); | ||
CASE(i16); | ||
CASE(i32); | ||
CASE(i64); | ||
CASE(u8); | ||
CASE(u16); | ||
CASE(u32); | ||
CASE(u64); | ||
default: | ||
OPENVINO_THROW("Unhandled data type ", element_type, " in evaluate_node()"); | ||
} | ||
#undef CASE | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
123 changes: 123 additions & 0 deletions
123
src/plugins/template/tests/functional/op_reference/search_sorted.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "openvino/op/search_sorted.hpp" | ||
|
||
#include <gtest/gtest.h> | ||
|
||
#include "base_reference_test.hpp" | ||
#include "openvino/op/constant.hpp" | ||
#include "openvino/op/parameter.hpp" | ||
|
||
using namespace reference_tests; | ||
using namespace ov; | ||
|
||
namespace { | ||
|
||
struct SearchSortedParams { | ||
PartialShape sortedShape; | ||
PartialShape valuesShape; | ||
bool rightMode; | ||
std::string testcaseName; | ||
reference_tests::Tensor sorted; | ||
reference_tests::Tensor values; | ||
reference_tests::Tensor expectedOutput; | ||
}; | ||
|
||
template <typename T> | ||
SearchSortedParams PrepareTestCaseParams(const PartialShape& sortedShape, | ||
const PartialShape& valuesShape, | ||
bool rightMode, | ||
const std::vector<T>& sortedData, | ||
const std::vector<T>& valuesData, | ||
const std::vector<int64_t>& expectedData, | ||
const std::string& testcaseName) { | ||
SearchSortedParams ret; | ||
const auto elementType = element::from<T>(); | ||
|
||
ret.sortedShape = sortedShape; | ||
ret.valuesShape = valuesShape; | ||
ret.rightMode = rightMode; | ||
ret.testcaseName = testcaseName; | ||
ret.sorted = reference_tests::Tensor(elementType, sortedShape.to_shape(), sortedData); | ||
ret.values = reference_tests::Tensor(elementType, valuesShape.to_shape(), valuesData); | ||
ret.expectedOutput = reference_tests::Tensor(element::Type_t::i64, valuesShape.to_shape(), expectedData); | ||
|
||
return ret; | ||
} | ||
|
||
class ReferenceSearchSortedTest : public testing::TestWithParam<SearchSortedParams>, public CommonReferenceTest { | ||
public: | ||
void SetUp() override { | ||
const auto& params = GetParam(); | ||
function = CreateFunction(params); | ||
inputData = {params.sorted.data, params.values.data}; | ||
refOutData = {params.expectedOutput.data}; | ||
} | ||
|
||
static std::string getTestCaseName(const testing::TestParamInfo<SearchSortedParams>& obj) { | ||
auto param = obj.param; | ||
std::ostringstream result; | ||
result << "type=" << param.sorted.data.get_element_type(); | ||
result << "_sortedShape=" << param.sortedShape; | ||
result << "_valuesShape=" << param.valuesShape; | ||
result << "_rightMode=" << param.rightMode; | ||
result << "_=" << param.testcaseName; | ||
|
||
return result.str(); | ||
} | ||
|
||
private: | ||
static std::shared_ptr<Model> CreateFunction(const SearchSortedParams& params) { | ||
const auto sorted = | ||
std::make_shared<op::v0::Parameter>(params.sorted.data.get_element_type(), params.sortedShape); | ||
const auto values = | ||
std::make_shared<op::v0::Parameter>(params.values.data.get_element_type(), params.valuesShape); | ||
|
||
const auto op = std::make_shared<op::v15::SearchSorted>(sorted, values, params.rightMode); | ||
|
||
return std::make_shared<Model>(NodeVector{op}, ParameterVector{sorted, values}); | ||
} | ||
}; | ||
|
||
TEST_P(ReferenceSearchSortedTest, CompareWithRefs) { | ||
Exec(); | ||
} | ||
|
||
template <element::Type_t ET> | ||
std::vector<SearchSortedParams> generateParams() { | ||
using T = typename element_type_traits<ET>::value_type; | ||
std::vector<SearchSortedParams> params; | ||
|
||
#define TEST_DATA(sorted_shape, values_shape, right_mode, sorted_data, values_data, expected_output_data, description) \ | ||
params.push_back(PrepareTestCaseParams<T>(sorted_shape, \ | ||
values_shape, \ | ||
right_mode, \ | ||
sorted_data, \ | ||
values_data, \ | ||
expected_output_data, \ | ||
description)); | ||
|
||
#include "unit_test_utils/tests_data/search_sorted_data.h" | ||
#undef TEST_DATA | ||
|
||
return params; | ||
} | ||
|
||
std::vector<SearchSortedParams> generateCombinedParams() { | ||
const std::vector<std::vector<SearchSortedParams>> generatedParams{generateParams<element::Type_t::i32>(), | ||
generateParams<element::Type_t::f32>()}; | ||
std::vector<SearchSortedParams> combinedParams; | ||
|
||
for (const auto& params : generatedParams) { | ||
combinedParams.insert(combinedParams.end(), params.begin(), params.end()); | ||
} | ||
return combinedParams; | ||
} | ||
|
||
INSTANTIATE_TEST_SUITE_P(smoke_SearchSorted_With_Hardcoded_Refs, | ||
ReferenceSearchSortedTest, | ||
testing::ValuesIn(generateCombinedParams()), | ||
ReferenceSearchSortedTest::getTestCaseName); | ||
} // namespace |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
86 changes: 86 additions & 0 deletions
86
src/tests/test_utils/unit_test_utils/tests_data/search_sorted_data.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
#pragma once | ||
|
||
#define LIST(...) \ | ||
{ __VA_ARGS__ } | ||
|
||
// TEST_DATA(sorted_shape, | ||
// values_shape, | ||
// right_mode, | ||
// sorted_data, | ||
// values_data, | ||
// expected_output_data, | ||
// description) | ||
|
||
// NOTE: expected output were generated using pyTorch.searchsorted implementation. | ||
|
||
TEST_DATA(LIST(5), | ||
LIST(2, 3), | ||
false, | ||
LIST(1, 3, 5, 7, 9), | ||
LIST(3, 6, 9, 3, 6, 9), | ||
LIST(1, 3, 4, 1, 3, 4), | ||
"1d_tensor_1"); | ||
|
||
TEST_DATA(LIST(5), | ||
LIST(4, 3), | ||
false, | ||
LIST(1, 3, 5, 7, 9), | ||
LIST(0, 6, 20, 1, 6, 9, 1, 0, 0, 9, 10, 20), | ||
LIST(0, 3, 5, 0, 3, 4, 0, 0, 0, 4, 5, 5), | ||
"1d_tensor_2"); | ||
|
||
TEST_DATA(LIST(5), | ||
LIST(4, 3), | ||
true, | ||
LIST(1, 3, 5, 7, 9), | ||
LIST(0, 6, 20, 1, 6, 9, 1, 0, 0, 9, 10, 20), | ||
LIST(0, 3, 5, 1, 3, 5, 1, 0, 0, 5, 5, 5), | ||
"1d_tensor_2_right_mode"); | ||
|
||
TEST_DATA(LIST(5), | ||
LIST(2, 2, 3), | ||
false, | ||
LIST(1, 3, 5, 7, 9), | ||
LIST(0, 6, 20, 1, 6, 9, 1, 0, 0, 9, 10, 20), | ||
LIST(0, 3, 5, 0, 3, 4, 0, 0, 0, 4, 5, 5), | ||
"1d_tensor_3"); | ||
|
||
TEST_DATA(LIST(5), | ||
LIST(2, 2, 3), | ||
true, | ||
LIST(1, 3, 5, 7, 9), | ||
LIST(0, 6, 20, 1, 6, 9, 1, 0, 0, 9, 10, 20), | ||
LIST(0, 3, 5, 1, 3, 5, 1, 0, 0, 5, 5, 5), | ||
"1d_tensor_3_right_mode"); | ||
|
||
TEST_DATA(LIST(2, 5), | ||
LIST(2, 3), | ||
false, | ||
LIST(1, 3, 5, 7, 9, 2, 4, 6, 8, 10), | ||
LIST(3, 6, 9, 3, 6, 9), | ||
LIST(1, 3, 4, 1, 2, 4), | ||
"nd_tensor_1"); | ||
|
||
TEST_DATA(LIST(2, 5), | ||
LIST(2, 3), | ||
true, | ||
LIST(1, 3, 5, 7, 9, 2, 4, 6, 8, 10), | ||
LIST(3, 6, 9, 3, 6, 9), | ||
LIST(2, 3, 5, 1, 3, 4), | ||
"nd_tensor_1_right_mode"); | ||
|
||
TEST_DATA(LIST(2, 2, 5), | ||
LIST(2, 2, 3), | ||
false, | ||
LIST(1, 3, 5, 7, 9, 0, 2, 4, 6, 8, -20, 5, 10, 23, 41, 100, 125, 130, 132, 139), | ||
LIST(0, 6, 20, 1, 6, 9, 1, 0, 0, 9, 10, 20), | ||
LIST(0, 3, 5, 1, 3, 5, 1, 1, 1, 0, 0, 0), | ||
"nd_tensor_2"); | ||
|
||
TEST_DATA(LIST(2, 2, 5), | ||
LIST(2, 2, 3), | ||
true, | ||
LIST(1, 3, 5, 7, 9, 0, 2, 4, 6, 8, -20, 5, 10, 23, 41, 100, 125, 130, 132, 139), | ||
LIST(0, 6, 20, 1, 6, 9, 1, 0, 0, 9, 10, 20), | ||
LIST(0, 3, 5, 1, 4, 5, 1, 1, 1, 0, 0, 0), | ||
"nd_tensor_2"); |