Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions ydb/library/yql/udfs/common/knn/knn-distance.h
Original file line number Diff line number Diff line change
Expand Up @@ -269,4 +269,55 @@ class KnnDistance {
return {};
}
}

// Computes the Hamming distance between two serialized vectors, i.e. the number
// of positions at which they differ.
//
// The format tag of each input is the byte at offset Size() - HeaderLen. An empty
// result is returned when either input is shorter than HeaderLen, when the two
// tags differ, or when the tag is not one of the formats handled below.
//
// Per format:
//   - FloatVector / Int8Vector / Uint8Vector: counts elements that compare
//     unequal with operator!=. For floats this means a NaN element is always
//     counted as a mismatch, while +0.0f and -0.0f are counted as equal.
//   - BitVector: sums popcount(d1 ^ d2) over the word pairs that
//     BitVectorHandleOp passes to the callback.
//
// NOTE(review): payload length checks are presumably done inside VectorFunc /
// BitVectorFunc -- confirm against their definitions.
static TDistanceResult HammingDistance(const TStringBuf& str1, const TStringBuf& str2) {
    // Guard the header read below: with fewer than HeaderLen bytes the index
    // Size() - HeaderLen would not point inside the buffer.
    if (Y_UNLIKELY(str1.Size() < HeaderLen || str2.Size() < HeaderLen)) {
        return {};
    }

    const ui8 format1 = str1.Data()[str1.Size() - HeaderLen];
    const ui8 format2 = str2.Data()[str2.Size() - HeaderLen];
    if (Y_UNLIKELY(format1 != format2)) {
        return {};
    }

    // Element-wise mismatch counter shared by every non-bit format; generic so
    // the same body serves float, i8 and ui8 without being repeated.
    const auto countMismatches = [](const auto* v1, const auto* v2, size_t len) {
        ui64 ret = 0;
        for (size_t i = 0; i < len; ++i) {
            if (v1[i] != v2[i]) {
                ++ret;
            }
        }
        return ret;
    };

    switch (format1) {
        case EFormat::FloatVector:
            return VectorFunc<float>(str1, str2, countMismatches);
        case EFormat::Int8Vector:
            return VectorFunc<i8>(str1, str2, countMismatches);
        case EFormat::Uint8Vector:
            return VectorFunc<ui8>(str1, str2, countMismatches);
        case EFormat::BitVector:
            return BitVectorFunc(str1, str2, [](const ui64* v1, const ui64* v2, ui64 bitLen) {
                ui64 ret = 0;
                // XOR leaves a set bit exactly where the two words differ.
                BitVectorHandleOp(bitLen, v1, v2, [&](ui64 d1, ui64 d2) {
                    ret += std::popcount(d1 ^ d2);
                });
                return ret;
            });
        default:
            return {};
    }
}
};
22 changes: 21 additions & 1 deletion ydb/library/yql/udfs/common/knn/knn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,25 @@ class TEuclideanDistance: public TDistanceBase<TEuclideanDistance> {
}
};

// Knn module function "HammingDistance": forwards its two string arguments to
// KnnDistance<>::HammingDistance and wraps the outcome in an unboxed value.
class THammingDistance: public TDistanceBase<THammingDistance> {
public:
    using TDistanceBase::TDistanceBase;

    // Name reported for this function.
    static const TStringRef& Name() {
        static auto functionName = TStringRef::Of("HammingDistance");
        return functionName;
    }

    // Evaluates the distance between args[0] and args[1].
    // Returns an empty value when the underlying computation yields no result.
    TUnboxedValue RunImpl(const IValueBuilder* valueBuilder, const TUnboxedValuePod* args) const {
        Y_UNUSED(valueBuilder);

        const auto lhs = args[0].AsStringRef();
        const auto rhs = args[1].AsStringRef();
        const auto distance = KnnDistance<>::HammingDistance(lhs, rhs);

        if (Y_UNLIKELY(!distance)) {
            return {};
        }
        return TUnboxedValuePod{*distance};
    }
};

// TODO IR for Distance functions?

SIMPLE_MODULE(TKnnModule,
Expand All @@ -420,6 +439,7 @@ SIMPLE_MODULE(TKnnModule,
TCosineSimilarity,
TCosineDistance,
TManhattanDistance,
TEuclideanDistance)
TEuclideanDistance,
THammingDistance)

REGISTER_MODULES(TKnnModule)
187 changes: 187 additions & 0 deletions ydb/library/yql/udfs/common/knn/test/cases/HammingDistance.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
-- Test cases for Knn::HammingDistance on element-wise (non-bit) vector formats.
-- Each case serializes two lists with a Knn::ToBinaryString* function and
-- selects the distance between them.

--fixed size vector: two float vectors of length 3, every position differs
$vector1 = Knn::ToBinaryStringFloat([1.0f, 2.0f, 3.0f]);
$vector2 = Knn::ToBinaryStringFloat([4.0f, 5.0f, 6.0f]);
select Knn::HammingDistance($vector1, $vector2);

--identical vectors: a vector compared with itself
select Knn::HammingDistance($vector1, $vector1);

--orthogonal vectors: no position holds the same value in both inputs
$orthogonal_vector1 = Knn::ToBinaryStringUint8([1ut, 0ut]);
$orthogonal_vector2 = Knn::ToBinaryStringUint8([0ut, 2ut]);
select Knn::HammingDistance($orthogonal_vector1, $orthogonal_vector2);

--float vector: every position differs (same data as the first case)
$float_vector1 = Knn::ToBinaryStringFloat([1.0f, 2.0f, 3.0f]);
$float_vector2 = Knn::ToBinaryStringFloat([4.0f, 5.0f, 6.0f]);
select Knn::HammingDistance($float_vector1, $float_vector2);

--float vector with some equal elements: positions 1 and 3 match, 2 and 4 differ
$float_vector3 = Knn::ToBinaryStringFloat([1.0f, 2.0f, 3.0f, 4.0f]);
$float_vector4 = Knn::ToBinaryStringFloat([1.0f, 5.0f, 3.0f, 7.0f]);
select Knn::HammingDistance($float_vector3, $float_vector4);

--byte (uint8) vector: every position differs
$byte_vector1 = Knn::ToBinaryStringUint8([1ut, 2ut, 3ut]);
$byte_vector2 = Knn::ToBinaryStringUint8([4ut, 5ut, 6ut]);
select Knn::HammingDistance($byte_vector1, $byte_vector2);

--byte (uint8) vector with some equal elements: positions 1 and 3 match, 2 and 4 differ
$byte_vector3 = Knn::ToBinaryStringUint8([1ut, 2ut, 3ut, 4ut]);
$byte_vector4 = Knn::ToBinaryStringUint8([1ut, 5ut, 3ut, 7ut]);
select Knn::HammingDistance($byte_vector3, $byte_vector4);

--int8 vector: every position differs
$int8_vector1 = Knn::ToBinaryStringInt8([1t, 2t, 3t]);
$int8_vector2 = Knn::ToBinaryStringInt8([4t, 5t, 6t]);
select Knn::HammingDistance($int8_vector1, $int8_vector2);

--int8 vector with some equal elements: positions 1 and 3 match, 2 and 4 differ
$int8_vector3 = Knn::ToBinaryStringInt8([1t, 2t, 3t, 4t]);
$int8_vector4 = Knn::ToBinaryStringInt8([1t, 5t, 3t, 7t]);
select Knn::HammingDistance($int8_vector3, $int8_vector4);

--bit vector
-- Inputs are float lists passed through Knn::ToBinaryStringBit.
-- NOTE(review): presumably each element becomes one bit chosen by its sign -- confirm
-- against the ToBinaryStringBit implementation.
-- 64 elements, all 1.0
$bitvector_positive = Knn::ToBinaryStringBit(ListReplicate(1.0f, 64));
-- 128 elements, all 1.0
$bitvector_positive_double_size = Knn::ToBinaryStringBit(ListReplicate(1.0f, 128));
-- 64 elements, all -1.0
$bitvector_negative = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 64));
-- range starting at -63.0 with upper bound 64.1
-- (presumably 128 values -63.0 .. 64.0 with step 1 -- confirm ListFromRange semantics)
$bitvector_negative_and_positive = Knn::ToBinaryStringBit(ListFromRange(-63.0f, 64.1f));
-- 64 elements alternating -1.0, 1.0
$bitvector_negative_and_positive_striped = Knn::ToBinaryStringBit(ListFlatten(ListReplicate([-1.0f, 1.0f], 32)));

-- all-positive vs all-negative, 64 elements each
select Knn::HammingDistance($bitvector_positive, $bitvector_negative);
-- 128 positive elements vs the mixed-sign range
select Knn::HammingDistance($bitvector_positive_double_size, $bitvector_negative_and_positive);
-- all-positive vs alternating signs, 64 elements each
select Knn::HammingDistance($bitvector_positive, $bitvector_negative_and_positive_striped);

--bit vector -- with tail
-- Vectors of 64 + k elements for k in {0, 4, 8, 16, 24, 32, 40, 48, 56, 60}:
-- the first 64 elements plus a tail of k elements.
-- NOTE(review): presumably meant to exercise one full 64-bit word plus a partial
-- trailing word in the bit-vector code path -- confirm.

-- all elements 1.0
$bitvector_pos_1_00 = Knn::ToBinaryStringBit(ListReplicate(1.0f, 64 + 00));
$bitvector_pos_1_04 = Knn::ToBinaryStringBit(ListReplicate(1.0f, 64 + 04));
$bitvector_pos_1_08 = Knn::ToBinaryStringBit(ListReplicate(1.0f, 64 + 08));
$bitvector_pos_1_16 = Knn::ToBinaryStringBit(ListReplicate(1.0f, 64 + 16));
$bitvector_pos_1_24 = Knn::ToBinaryStringBit(ListReplicate(1.0f, 64 + 24));
$bitvector_pos_1_32 = Knn::ToBinaryStringBit(ListReplicate(1.0f, 64 + 32));
$bitvector_pos_1_40 = Knn::ToBinaryStringBit(ListReplicate(1.0f, 64 + 40));
$bitvector_pos_1_48 = Knn::ToBinaryStringBit(ListReplicate(1.0f, 64 + 48));
$bitvector_pos_1_56 = Knn::ToBinaryStringBit(ListReplicate(1.0f, 64 + 56));
$bitvector_pos_1_60 = Knn::ToBinaryStringBit(ListReplicate(1.0f, 64 + 60));

-- all elements -1.0
$bitvector_neg_1_00 = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 64 + 00));
$bitvector_neg_1_04 = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 64 + 04));
$bitvector_neg_1_08 = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 64 + 08));
$bitvector_neg_1_16 = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 64 + 16));
$bitvector_neg_1_24 = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 64 + 24));
$bitvector_neg_1_32 = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 64 + 32));
$bitvector_neg_1_40 = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 64 + 40));
$bitvector_neg_1_48 = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 64 + 48));
$bitvector_neg_1_56 = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 64 + 56));
$bitvector_neg_1_60 = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 64 + 60));

-- In every select below the first column echoes the vector length.

-- positive vector against itself
select 64 + 00, Knn::HammingDistance($bitvector_pos_1_00, $bitvector_pos_1_00);
select 64 + 04, Knn::HammingDistance($bitvector_pos_1_04, $bitvector_pos_1_04);
select 64 + 08, Knn::HammingDistance($bitvector_pos_1_08, $bitvector_pos_1_08);
select 64 + 16, Knn::HammingDistance($bitvector_pos_1_16, $bitvector_pos_1_16);
select 64 + 24, Knn::HammingDistance($bitvector_pos_1_24, $bitvector_pos_1_24);
select 64 + 32, Knn::HammingDistance($bitvector_pos_1_32, $bitvector_pos_1_32);
select 64 + 40, Knn::HammingDistance($bitvector_pos_1_40, $bitvector_pos_1_40);
select 64 + 48, Knn::HammingDistance($bitvector_pos_1_48, $bitvector_pos_1_48);
select 64 + 56, Knn::HammingDistance($bitvector_pos_1_56, $bitvector_pos_1_56);
select 64 + 60, Knn::HammingDistance($bitvector_pos_1_60, $bitvector_pos_1_60);

-- negative vector against itself
select 64 + 00, Knn::HammingDistance($bitvector_neg_1_00, $bitvector_neg_1_00);
select 64 + 04, Knn::HammingDistance($bitvector_neg_1_04, $bitvector_neg_1_04);
select 64 + 08, Knn::HammingDistance($bitvector_neg_1_08, $bitvector_neg_1_08);
select 64 + 16, Knn::HammingDistance($bitvector_neg_1_16, $bitvector_neg_1_16);
select 64 + 24, Knn::HammingDistance($bitvector_neg_1_24, $bitvector_neg_1_24);
select 64 + 32, Knn::HammingDistance($bitvector_neg_1_32, $bitvector_neg_1_32);
select 64 + 40, Knn::HammingDistance($bitvector_neg_1_40, $bitvector_neg_1_40);
select 64 + 48, Knn::HammingDistance($bitvector_neg_1_48, $bitvector_neg_1_48);
select 64 + 56, Knn::HammingDistance($bitvector_neg_1_56, $bitvector_neg_1_56);
select 64 + 60, Knn::HammingDistance($bitvector_neg_1_60, $bitvector_neg_1_60);

-- positive vs negative of the same length
select 64 + 00, Knn::HammingDistance($bitvector_pos_1_00, $bitvector_neg_1_00);
select 64 + 04, Knn::HammingDistance($bitvector_pos_1_04, $bitvector_neg_1_04);
select 64 + 08, Knn::HammingDistance($bitvector_pos_1_08, $bitvector_neg_1_08);
select 64 + 16, Knn::HammingDistance($bitvector_pos_1_16, $bitvector_neg_1_16);
select 64 + 24, Knn::HammingDistance($bitvector_pos_1_24, $bitvector_neg_1_24);
select 64 + 32, Knn::HammingDistance($bitvector_pos_1_32, $bitvector_neg_1_32);
select 64 + 40, Knn::HammingDistance($bitvector_pos_1_40, $bitvector_neg_1_40);
select 64 + 48, Knn::HammingDistance($bitvector_pos_1_48, $bitvector_neg_1_48);
select 64 + 56, Knn::HammingDistance($bitvector_pos_1_56, $bitvector_neg_1_56);
select 64 + 60, Knn::HammingDistance($bitvector_pos_1_60, $bitvector_neg_1_60);

-- negative vs positive (arguments swapped relative to the previous group)
select 64 + 00, Knn::HammingDistance($bitvector_neg_1_00, $bitvector_pos_1_00);
select 64 + 04, Knn::HammingDistance($bitvector_neg_1_04, $bitvector_pos_1_04);
select 64 + 08, Knn::HammingDistance($bitvector_neg_1_08, $bitvector_pos_1_08);
select 64 + 16, Knn::HammingDistance($bitvector_neg_1_16, $bitvector_pos_1_16);
select 64 + 24, Knn::HammingDistance($bitvector_neg_1_24, $bitvector_pos_1_24);
select 64 + 32, Knn::HammingDistance($bitvector_neg_1_32, $bitvector_pos_1_32);
select 64 + 40, Knn::HammingDistance($bitvector_neg_1_40, $bitvector_pos_1_40);
select 64 + 48, Knn::HammingDistance($bitvector_neg_1_48, $bitvector_pos_1_48);
select 64 + 56, Knn::HammingDistance($bitvector_neg_1_56, $bitvector_pos_1_56);
select 64 + 60, Knn::HammingDistance($bitvector_neg_1_60, $bitvector_pos_1_60);

--bit vector -- only tail
-- Vectors of k elements for k in {0, 4, 8, 16, 24, 32, 40, 48, 56, 60}, i.e.
-- always fewer than 64 elements; k = 0 is the empty list.
-- NOTE(review): presumably meant to exercise inputs that fit entirely in a
-- partial trailing word of the bit-vector code path -- confirm.

-- all elements 1.0
$bitvector_pos_00 = Knn::ToBinaryStringBit(ListReplicate(1.0f, 00));
$bitvector_pos_04 = Knn::ToBinaryStringBit(ListReplicate(1.0f, 04));
$bitvector_pos_08 = Knn::ToBinaryStringBit(ListReplicate(1.0f, 08));
$bitvector_pos_16 = Knn::ToBinaryStringBit(ListReplicate(1.0f, 16));
$bitvector_pos_24 = Knn::ToBinaryStringBit(ListReplicate(1.0f, 24));
$bitvector_pos_32 = Knn::ToBinaryStringBit(ListReplicate(1.0f, 32));
$bitvector_pos_40 = Knn::ToBinaryStringBit(ListReplicate(1.0f, 40));
$bitvector_pos_48 = Knn::ToBinaryStringBit(ListReplicate(1.0f, 48));
$bitvector_pos_56 = Knn::ToBinaryStringBit(ListReplicate(1.0f, 56));
$bitvector_pos_60 = Knn::ToBinaryStringBit(ListReplicate(1.0f, 60));

-- all elements -1.0
$bitvector_neg_00 = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 00));
$bitvector_neg_04 = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 04));
$bitvector_neg_08 = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 08));
$bitvector_neg_16 = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 16));
$bitvector_neg_24 = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 24));
$bitvector_neg_32 = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 32));
$bitvector_neg_40 = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 40));
$bitvector_neg_48 = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 48));
$bitvector_neg_56 = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 56));
$bitvector_neg_60 = Knn::ToBinaryStringBit(ListReplicate(-1.0f, 60));

-- In every select below the first column echoes the vector length.

-- positive vector against itself
select 00, Knn::HammingDistance($bitvector_pos_00, $bitvector_pos_00);
select 04, Knn::HammingDistance($bitvector_pos_04, $bitvector_pos_04);
select 08, Knn::HammingDistance($bitvector_pos_08, $bitvector_pos_08);
select 16, Knn::HammingDistance($bitvector_pos_16, $bitvector_pos_16);
select 24, Knn::HammingDistance($bitvector_pos_24, $bitvector_pos_24);
select 32, Knn::HammingDistance($bitvector_pos_32, $bitvector_pos_32);
select 40, Knn::HammingDistance($bitvector_pos_40, $bitvector_pos_40);
select 48, Knn::HammingDistance($bitvector_pos_48, $bitvector_pos_48);
select 56, Knn::HammingDistance($bitvector_pos_56, $bitvector_pos_56);
select 60, Knn::HammingDistance($bitvector_pos_60, $bitvector_pos_60);

-- negative vector against itself
select 00, Knn::HammingDistance($bitvector_neg_00, $bitvector_neg_00);
select 04, Knn::HammingDistance($bitvector_neg_04, $bitvector_neg_04);
select 08, Knn::HammingDistance($bitvector_neg_08, $bitvector_neg_08);
select 16, Knn::HammingDistance($bitvector_neg_16, $bitvector_neg_16);
select 24, Knn::HammingDistance($bitvector_neg_24, $bitvector_neg_24);
select 32, Knn::HammingDistance($bitvector_neg_32, $bitvector_neg_32);
select 40, Knn::HammingDistance($bitvector_neg_40, $bitvector_neg_40);
select 48, Knn::HammingDistance($bitvector_neg_48, $bitvector_neg_48);
select 56, Knn::HammingDistance($bitvector_neg_56, $bitvector_neg_56);
select 60, Knn::HammingDistance($bitvector_neg_60, $bitvector_neg_60);

-- positive vs negative of the same length
select 00, Knn::HammingDistance($bitvector_pos_00, $bitvector_neg_00);
select 04, Knn::HammingDistance($bitvector_pos_04, $bitvector_neg_04);
select 08, Knn::HammingDistance($bitvector_pos_08, $bitvector_neg_08);
select 16, Knn::HammingDistance($bitvector_pos_16, $bitvector_neg_16);
select 24, Knn::HammingDistance($bitvector_pos_24, $bitvector_neg_24);
select 32, Knn::HammingDistance($bitvector_pos_32, $bitvector_neg_32);
select 40, Knn::HammingDistance($bitvector_pos_40, $bitvector_neg_40);
select 48, Knn::HammingDistance($bitvector_pos_48, $bitvector_neg_48);
select 56, Knn::HammingDistance($bitvector_pos_56, $bitvector_neg_56);
select 60, Knn::HammingDistance($bitvector_pos_60, $bitvector_neg_60);

-- negative vs positive (arguments swapped relative to the previous group)
select 00, Knn::HammingDistance($bitvector_neg_00, $bitvector_pos_00);
select 04, Knn::HammingDistance($bitvector_neg_04, $bitvector_pos_04);
select 08, Knn::HammingDistance($bitvector_neg_08, $bitvector_pos_08);
select 16, Knn::HammingDistance($bitvector_neg_16, $bitvector_pos_16);
select 24, Knn::HammingDistance($bitvector_neg_24, $bitvector_pos_24);
select 32, Knn::HammingDistance($bitvector_neg_32, $bitvector_pos_32);
select 40, Knn::HammingDistance($bitvector_neg_40, $bitvector_pos_40);
select 48, Knn::HammingDistance($bitvector_neg_48, $bitvector_pos_48);
select 56, Knn::HammingDistance($bitvector_neg_56, $bitvector_pos_56);
select 60, Knn::HammingDistance($bitvector_neg_60, $bitvector_pos_60);