diff --git a/ydb/library/yql/udfs/common/knn/knn-distance.h b/ydb/library/yql/udfs/common/knn/knn-distance.h
index 8c375b935616..cf319efaf0a9 100644
--- a/ydb/library/yql/udfs/common/knn/knn-distance.h
+++ b/ydb/library/yql/udfs/common/knn/knn-distance.h
@@ -269,4 +269,69 @@ class KnnDistance {
                 return {};
         }
     }
+
+    // Hamming distance between two serialized vectors of the same format: the
+    // number of positions whose values differ. Float, Int8 and Uint8 vectors are
+    // compared element by element; bit vectors are compared bit by bit.
+    // Returns an empty TDistanceResult when either input is shorter than the
+    // header, when the two formats differ, or when the format is not handled.
+    // NOTE(review): float elements are compared with operator!=, so a NaN element
+    // counts as different even from itself -- confirm this is intended.
+    static TDistanceResult HammingDistance(const TStringBuf& str1, const TStringBuf& str2) {
+        // The format byte sits at offset Size() - HeaderLen; without this guard an
+        // empty input would make that unsigned subtraction wrap and read out of bounds.
+        if (Y_UNLIKELY(str1.Size() < HeaderLen || str2.Size() < HeaderLen)) {
+            return {};
+        }
+
+        const ui8 format1 = str1.Data()[str1.Size() - HeaderLen];
+        const ui8 format2 = str2.Data()[str2.Size() - HeaderLen];
+        if (Y_UNLIKELY(format1 != format2)) {
+            return {};
+        }
+
+        switch (format1) {
+            case EFormat::FloatVector:
+                return VectorFunc(str1, str2, [](const float* v1, const float* v2, size_t len) {
+                    ui64 ret = 0;
+                    for (size_t i = 0; i < len; ++i) {
+                        if (v1[i] != v2[i]) {
+                            ++ret;
+                        }
+                    }
+                    return ret;
+                });
+            case EFormat::Int8Vector:
+                return VectorFunc(str1, str2, [](const i8* v1, const i8* v2, size_t len) {
+                    ui64 ret = 0;
+                    for (size_t i = 0; i < len; ++i) {
+                        if (v1[i] != v2[i]) {
+                            ++ret;
+                        }
+                    }
+                    return ret;
+                });
+            case EFormat::Uint8Vector:
+                return VectorFunc(str1, str2, [](const ui8* v1, const ui8* v2, size_t len) {
+                    ui64 ret = 0;
+                    for (size_t i = 0; i < len; ++i) {
+                        if (v1[i] != v2[i]) {
+                            ++ret;
+                        }
+                    }
+                    return ret;
+                });
+            case EFormat::BitVector:
+                return BitVectorFunc(str1, str2, [](const ui64* v1, const ui64* v2, ui64 bitLen) {
+                    ui64 ret = 0;
+                    BitVectorHandleOp(bitLen, v1, v2, [&](ui64 d1, ui64 d2) {
+                        // XOR leaves a set bit at every position where the two words disagree.
+                        ret += std::popcount(d1 ^ d2);
+                    });
+                    return ret;
+                });
+            default:
+                return {};
+        }
+    }
 };
diff --git a/ydb/library/yql/udfs/common/knn/knn.cpp b/ydb/library/yql/udfs/common/knn/knn.cpp
index 2810a3006c11..d7abf84732a2 100644
--- a/ydb/library/yql/udfs/common/knn/knn.cpp
+++ b/ydb/library/yql/udfs/common/knn/knn.cpp
@@ -408,6 +408,27 @@ class TEuclideanDistance: public TDistanceBase {
     }
 };
 
+// UDF wrapper exposing KnnDistance<>::HammingDistance under the name "HammingDistance".
+// Yields an empty TUnboxedValue when the distance cannot be computed.
+class THammingDistance: public TDistanceBase {
+public:
+    using TDistanceBase::TDistanceBase;
+
+    static const TStringRef& Name() {
+        static auto name = TStringRef::Of("HammingDistance");
+        return name;
+    }
+
+    TUnboxedValue RunImpl(const IValueBuilder* valueBuilder, const TUnboxedValuePod* args) const {
+        Y_UNUSED(valueBuilder);
+        const auto ret = KnnDistance<>::HammingDistance(args[0].AsStringRef(), args[1].AsStringRef());
+        if (Y_UNLIKELY(!ret)) {
+            return {};
+        }
+        return TUnboxedValuePod{*ret};
+    }
+};
+
 // TODO IR for Distance functions?
 
 SIMPLE_MODULE(TKnnModule,
@@ -420,6 +441,7 @@ SIMPLE_MODULE(TKnnModule,
               TCosineSimilarity,
               TCosineDistance,
               TManhattanDistance,
-              TEuclideanDistance)
+              TEuclideanDistance,
+              THammingDistance)
 
 REGISTER_MODULES(TKnnModule)
diff --git a/ydb/library/yql/udfs/common/knn/test/cases/HammingDistance.sql b/ydb/library/yql/udfs/common/knn/test/cases/HammingDistance.sql
new file mode 100644
index 000000000000..eb9173e943f2
--- /dev/null
+++ b/ydb/library/yql/udfs/common/knn/test/cases/HammingDistance.sql
@@ -0,0 +1,187 @@
+--fixed size vector
+$vector1 = Knn::ToBinaryStringFloat([1.0f, 2.0f, 3.0f]);
+$vector2 = Knn::ToBinaryStringFloat([4.0f, 5.0f, 6.0f]);
+select Knn::HammingDistance($vector1, $vector2);
+
+--exact vectors
+select Knn::HammingDistance($vector1, $vector1);
+
+--orthogonal vectors
+$orthogonal_vector1 = Knn::ToBinaryStringUint8([1ut, 0ut]);
+$orthogonal_vector2 = Knn::ToBinaryStringUint8([0ut, 2ut]);
+select Knn::HammingDistance($orthogonal_vector1, $orthogonal_vector2);
+
+--float vector
+$float_vector1 = Knn::ToBinaryStringFloat([1.0f, 2.0f, 3.0f]);
+$float_vector2 = Knn::ToBinaryStringFloat([4.0f, 5.0f, 6.0f]);
+select Knn::HammingDistance($float_vector1, $float_vector2);
+
+--float vector with some equal elements
+$float_vector3 = Knn::ToBinaryStringFloat([1.0f, 2.0f, 3.0f, 4.0f]);
+$float_vector4 = Knn::ToBinaryStringFloat([1.0f, 5.0f, 3.0f, 7.0f]);
+select Knn::HammingDistance($float_vector3, $float_vector4);
+
+--byte vector
+$byte_vector1 = Knn::ToBinaryStringUint8([1ut, 2ut, 3ut]);
+$byte_vector2 = Knn::ToBinaryStringUint8([4ut, 5ut, 6ut]);
+select Knn::HammingDistance($byte_vector1, $byte_vector2);
+
+--byte vector with some equal elements
+$byte_vector3 = Knn::ToBinaryStringUint8([1ut, 2ut, 3ut, 4ut]);
+$byte_vector4 = Knn::ToBinaryStringUint8([1ut, 5ut, 3ut, 7ut]);
+select Knn::HammingDistance($byte_vector3, $byte_vector4);
+
+--int8 vector
+$int8_vector1 = Knn::ToBinaryStringInt8([1t, 2t, 3t]);
+$int8_vector2 = Knn::ToBinaryStringInt8([4t, 5t, 6t]);
+select Knn::HammingDistance($int8_vector1, $int8_vector2);
+
+--int8 vector with some equal elements
+$int8_vector3 = Knn::ToBinaryStringInt8([1t, 2t, 3t, 4t]);
+$int8_vector4 = Knn::ToBinaryStringInt8([1t, 5t, 3t, 7t]);
+select Knn::HammingDistance($int8_vector3, $int8_vector4);
+
+--bit vector
+$bitvector_positive = Knn::ToBinaryStringBit(ListReplicate(1.0f, 64));
+$bitvector_positive_double_size = Knn::ToBinaryStringBit(ListReplicate(1.0f, 128));
+$bitvector_negative = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 64));
+$bitvector_negative_and_positive = Knn::ToBinaryStringBit(ListFromRange(-63.0f, 64.1f));
+$bitvector_negative_and_positive_striped = Knn::ToBinaryStringBit(ListFlatten(ListReplicate([-1.0f, 1.0f], 32)));
+
+select Knn::HammingDistance($bitvector_positive, $bitvector_negative);
+select Knn::HammingDistance($bitvector_positive_double_size, $bitvector_negative_and_positive);
+select Knn::HammingDistance($bitvector_positive, $bitvector_negative_and_positive_striped);
+
+--bit vector -- with tail
+$bitvector_pos_1_00 = Knn::ToBinaryStringBit(ListReplicate(1.0f, 64 + 00));
+$bitvector_pos_1_04 = Knn::ToBinaryStringBit(ListReplicate(1.0f, 64 + 04));
+$bitvector_pos_1_08 = Knn::ToBinaryStringBit(ListReplicate(1.0f, 64 + 08));
+$bitvector_pos_1_16 = Knn::ToBinaryStringBit(ListReplicate(1.0f, 64 + 16));
+$bitvector_pos_1_24 = Knn::ToBinaryStringBit(ListReplicate(1.0f, 64 + 24));
+$bitvector_pos_1_32 = Knn::ToBinaryStringBit(ListReplicate(1.0f, 64 + 32));
+$bitvector_pos_1_40 = Knn::ToBinaryStringBit(ListReplicate(1.0f, 64 + 40));
+$bitvector_pos_1_48 = Knn::ToBinaryStringBit(ListReplicate(1.0f, 64 + 48));
+$bitvector_pos_1_56 = Knn::ToBinaryStringBit(ListReplicate(1.0f, 64 + 56));
+$bitvector_pos_1_60 = Knn::ToBinaryStringBit(ListReplicate(1.0f, 64 + 60));
+
+$bitvector_neg_1_00 = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 64 + 00));
+$bitvector_neg_1_04 = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 64 + 04));
+$bitvector_neg_1_08 = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 64 + 08));
+$bitvector_neg_1_16 = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 64 + 16));
+$bitvector_neg_1_24 = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 64 + 24));
+$bitvector_neg_1_32 = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 64 + 32));
+$bitvector_neg_1_40 = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 64 + 40));
+$bitvector_neg_1_48 = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 64 + 48));
+$bitvector_neg_1_56 = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 64 + 56));
+$bitvector_neg_1_60 = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 64 + 60));
+
+select 64 + 00, Knn::HammingDistance($bitvector_pos_1_00, $bitvector_pos_1_00);
+select 64 + 04, Knn::HammingDistance($bitvector_pos_1_04, $bitvector_pos_1_04);
+select 64 + 08, Knn::HammingDistance($bitvector_pos_1_08, $bitvector_pos_1_08);
+select 64 + 16, Knn::HammingDistance($bitvector_pos_1_16, $bitvector_pos_1_16);
+select 64 + 24, Knn::HammingDistance($bitvector_pos_1_24, $bitvector_pos_1_24);
+select 64 + 32, Knn::HammingDistance($bitvector_pos_1_32, $bitvector_pos_1_32);
+select 64 + 40, Knn::HammingDistance($bitvector_pos_1_40, $bitvector_pos_1_40);
+select 64 + 48, Knn::HammingDistance($bitvector_pos_1_48, $bitvector_pos_1_48);
+select 64 + 56, Knn::HammingDistance($bitvector_pos_1_56, $bitvector_pos_1_56);
+select 64 + 60, Knn::HammingDistance($bitvector_pos_1_60, $bitvector_pos_1_60);
+
+select 64 + 00, Knn::HammingDistance($bitvector_neg_1_00, $bitvector_neg_1_00);
+select 64 + 04, Knn::HammingDistance($bitvector_neg_1_04, $bitvector_neg_1_04);
+select 64 + 08, Knn::HammingDistance($bitvector_neg_1_08, $bitvector_neg_1_08);
+select 64 + 16, Knn::HammingDistance($bitvector_neg_1_16, $bitvector_neg_1_16);
+select 64 + 24, Knn::HammingDistance($bitvector_neg_1_24, $bitvector_neg_1_24);
+select 64 + 32, Knn::HammingDistance($bitvector_neg_1_32, $bitvector_neg_1_32);
+select 64 + 40, Knn::HammingDistance($bitvector_neg_1_40, $bitvector_neg_1_40);
+select 64 + 48, Knn::HammingDistance($bitvector_neg_1_48, $bitvector_neg_1_48);
+select 64 + 56, Knn::HammingDistance($bitvector_neg_1_56, $bitvector_neg_1_56);
+select 64 + 60, Knn::HammingDistance($bitvector_neg_1_60, $bitvector_neg_1_60);
+
+select 64 + 00, Knn::HammingDistance($bitvector_pos_1_00, $bitvector_neg_1_00);
+select 64 + 04, Knn::HammingDistance($bitvector_pos_1_04, $bitvector_neg_1_04);
+select 64 + 08, Knn::HammingDistance($bitvector_pos_1_08, $bitvector_neg_1_08);
+select 64 + 16, Knn::HammingDistance($bitvector_pos_1_16, $bitvector_neg_1_16);
+select 64 + 24, Knn::HammingDistance($bitvector_pos_1_24, $bitvector_neg_1_24);
+select 64 + 32, Knn::HammingDistance($bitvector_pos_1_32, $bitvector_neg_1_32);
+select 64 + 40, Knn::HammingDistance($bitvector_pos_1_40, $bitvector_neg_1_40);
+select 64 + 48, Knn::HammingDistance($bitvector_pos_1_48, $bitvector_neg_1_48);
+select 64 + 56, Knn::HammingDistance($bitvector_pos_1_56, $bitvector_neg_1_56);
+select 64 + 60, Knn::HammingDistance($bitvector_pos_1_60, $bitvector_neg_1_60);
+
+select 64 + 00, Knn::HammingDistance($bitvector_neg_1_00, $bitvector_pos_1_00);
+select 64 + 04, Knn::HammingDistance($bitvector_neg_1_04, $bitvector_pos_1_04);
+select 64 + 08, Knn::HammingDistance($bitvector_neg_1_08, $bitvector_pos_1_08);
+select 64 + 16, Knn::HammingDistance($bitvector_neg_1_16, $bitvector_pos_1_16);
+select 64 + 24, Knn::HammingDistance($bitvector_neg_1_24, $bitvector_pos_1_24);
+select 64 + 32, Knn::HammingDistance($bitvector_neg_1_32, $bitvector_pos_1_32);
+select 64 + 40, Knn::HammingDistance($bitvector_neg_1_40, $bitvector_pos_1_40);
+select 64 + 48, Knn::HammingDistance($bitvector_neg_1_48, $bitvector_pos_1_48);
+select 64 + 56, Knn::HammingDistance($bitvector_neg_1_56, $bitvector_pos_1_56);
+select 64 + 60, Knn::HammingDistance($bitvector_neg_1_60, $bitvector_pos_1_60);
+
+--bit vector -- only tail
+$bitvector_pos_00 = Knn::ToBinaryStringBit(ListReplicate(1.0f, 00));
+$bitvector_pos_04 = Knn::ToBinaryStringBit(ListReplicate(1.0f, 04));
+$bitvector_pos_08 = Knn::ToBinaryStringBit(ListReplicate(1.0f, 08));
+$bitvector_pos_16 = Knn::ToBinaryStringBit(ListReplicate(1.0f, 16));
+$bitvector_pos_24 = Knn::ToBinaryStringBit(ListReplicate(1.0f, 24));
+$bitvector_pos_32 = Knn::ToBinaryStringBit(ListReplicate(1.0f, 32));
+$bitvector_pos_40 = Knn::ToBinaryStringBit(ListReplicate(1.0f, 40));
+$bitvector_pos_48 = Knn::ToBinaryStringBit(ListReplicate(1.0f, 48));
+$bitvector_pos_56 = Knn::ToBinaryStringBit(ListReplicate(1.0f, 56));
+$bitvector_pos_60 = Knn::ToBinaryStringBit(ListReplicate(1.0f, 60));
+
+$bitvector_neg_00 = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 00));
+$bitvector_neg_04 = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 04));
+$bitvector_neg_08 = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 08));
+$bitvector_neg_16 = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 16));
+$bitvector_neg_24 = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 24));
+$bitvector_neg_32 = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 32));
+$bitvector_neg_40 = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 40));
+$bitvector_neg_48 = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 48));
+$bitvector_neg_56 = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 56));
+$bitvector_neg_60 = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 60));
+
+select 00, Knn::HammingDistance($bitvector_pos_00, $bitvector_pos_00);
+select 04, Knn::HammingDistance($bitvector_pos_04, $bitvector_pos_04);
+select 08, Knn::HammingDistance($bitvector_pos_08, $bitvector_pos_08);
+select 16, Knn::HammingDistance($bitvector_pos_16, $bitvector_pos_16);
+select 24, Knn::HammingDistance($bitvector_pos_24, $bitvector_pos_24);
+select 32, Knn::HammingDistance($bitvector_pos_32, $bitvector_pos_32);
+select 40, Knn::HammingDistance($bitvector_pos_40, $bitvector_pos_40);
+select 48, Knn::HammingDistance($bitvector_pos_48, $bitvector_pos_48);
+select 56, Knn::HammingDistance($bitvector_pos_56, $bitvector_pos_56);
+select 60, Knn::HammingDistance($bitvector_pos_60, $bitvector_pos_60);
+
+select 00, Knn::HammingDistance($bitvector_neg_00, $bitvector_neg_00);
+select 04, Knn::HammingDistance($bitvector_neg_04, $bitvector_neg_04);
+select 08, Knn::HammingDistance($bitvector_neg_08, $bitvector_neg_08);
+select 16, Knn::HammingDistance($bitvector_neg_16, $bitvector_neg_16);
+select 24, Knn::HammingDistance($bitvector_neg_24, $bitvector_neg_24);
+select 32, Knn::HammingDistance($bitvector_neg_32, $bitvector_neg_32);
+select 40, Knn::HammingDistance($bitvector_neg_40, $bitvector_neg_40);
+select 48, Knn::HammingDistance($bitvector_neg_48, $bitvector_neg_48);
+select 56, Knn::HammingDistance($bitvector_neg_56, $bitvector_neg_56);
+select 60, Knn::HammingDistance($bitvector_neg_60, $bitvector_neg_60);
+
+select 00, Knn::HammingDistance($bitvector_pos_00, $bitvector_neg_00);
+select 04, Knn::HammingDistance($bitvector_pos_04, $bitvector_neg_04);
+select 08, Knn::HammingDistance($bitvector_pos_08, $bitvector_neg_08);
+select 16, Knn::HammingDistance($bitvector_pos_16, $bitvector_neg_16);
+select 24, Knn::HammingDistance($bitvector_pos_24, $bitvector_neg_24);
+select 32, Knn::HammingDistance($bitvector_pos_32, $bitvector_neg_32);
+select 40, Knn::HammingDistance($bitvector_pos_40, $bitvector_neg_40);
+select 48, Knn::HammingDistance($bitvector_pos_48, $bitvector_neg_48);
+select 56, Knn::HammingDistance($bitvector_pos_56, $bitvector_neg_56);
+select 60, Knn::HammingDistance($bitvector_pos_60, $bitvector_neg_60);
+
+select 00, Knn::HammingDistance($bitvector_neg_00, $bitvector_pos_00);
+select 04, Knn::HammingDistance($bitvector_neg_04, $bitvector_pos_04);
+select 08, Knn::HammingDistance($bitvector_neg_08, $bitvector_pos_08);
+select 16, Knn::HammingDistance($bitvector_neg_16, $bitvector_pos_16);
+select 24, Knn::HammingDistance($bitvector_neg_24, $bitvector_pos_24);
+select 32, Knn::HammingDistance($bitvector_neg_32, $bitvector_pos_32);
+select 40, Knn::HammingDistance($bitvector_neg_40, $bitvector_pos_40);
+select 48, Knn::HammingDistance($bitvector_neg_48, $bitvector_pos_48);
+select 56, Knn::HammingDistance($bitvector_neg_56, $bitvector_pos_56);
+select 60, Knn::HammingDistance($bitvector_neg_60, $bitvector_pos_60);