diff --git a/ydb/library/yql/udfs/common/knn/knn-defines.h b/ydb/library/yql/udfs/common/knn/knn-defines.h index 647a4c9c1866..914f976b9e36 100644 --- a/ydb/library/yql/udfs/common/knn/knn-defines.h +++ b/ydb/library/yql/udfs/common/knn/knn-defines.h @@ -7,6 +7,7 @@ enum EFormat: ui8 { Uint8Vector = 2, // 1-byte per element, better than Int8 for positive-only Float Int8Vector = 3, // 1-byte per element BitVector = 10, // 1-bit per element + QBitVector = 11, // 1-bit per element, transposed layout for better cache locality }; template diff --git a/ydb/library/yql/udfs/common/knn/knn-distance.h b/ydb/library/yql/udfs/common/knn/knn-distance.h index 8c375b935616..5f94dde30fa9 100644 --- a/ydb/library/yql/udfs/common/knn/knn-distance.h +++ b/ydb/library/yql/udfs/common/knn/knn-distance.h @@ -11,6 +11,7 @@ #include #include +#include namespace { @@ -42,6 +43,16 @@ namespace { return {reinterpret_cast(buf), len}; } + inline TBitArray GetQBitArray(const TStringBuf& str) { + if (Y_UNLIKELY(str.Size() < HeaderLen + sizeof(ui32))) { + return {}; + } + const char* buf = str.Data(); + ui32 elementCount; + std::memcpy(&elementCount, buf + str.Size() - HeaderLen - sizeof(ui32), sizeof(ui32)); + return {reinterpret_cast(buf), elementCount}; + } + } // namespace template @@ -118,6 +129,13 @@ class KnnDistance { return VectorFuncImpl(v1, v2, bitLen1, bitLen2, std::forward(func)); } + template + static auto QBitVectorFunc(const TStringBuf& str1, const TStringBuf& str2, Func&& func) { + auto [v1, bitLen1] = GetQBitArray(str1); + auto [v2, bitLen2] = GetQBitArray(str2); + return VectorFuncImpl(v1, v2, bitLen1, bitLen2, std::forward(func)); + } + public: static TDistanceResult ManhattanDistance(const TStringBuf& str1, const TStringBuf& str2) { const ui8 format1 = str1.Data()[str1.Size() - HeaderLen]; @@ -145,6 +163,26 @@ class KnnDistance { BitVectorHandleOp(bitLen, v1, v2, [&](ui64 d1, ui64 d2) { ret += std::popcount(d1 ^ d2); }); + return ret; + }); + case EFormat::QBitVector: + return QBitVectorFunc(str1, str2, [](const ui64* v1, const ui64* v2, ui64 elementCount) { + ui64 ret = 0; + const size_t fullBlocks = elementCount / 64; + const size_t remainder = elementCount % 64; + + // Process full blocks + for (size_t i = 0; i < fullBlocks; ++i) { + ret += std::popcount(v1[i] ^ v2[i]); + } + + // Process partial block with mask + if (remainder > 0) { + // Defensive: remainder is mathematically [0, 63], but protect against UB + const ui64 mask = (remainder >= 64) ? ~ui64{0} : ((ui64{1} << remainder) - 1); + ret += std::popcount((v1[fullBlocks] ^ v2[fullBlocks]) & mask); + } + return ret; }); default: @@ -178,6 +216,26 @@ class KnnDistance { BitVectorHandleOp(bitLen, v1, v2, [&](ui64 d1, ui64 d2) { ret += std::popcount(d1 ^ d2); }); + return NPrivate::NL2Distance::L2DistanceSqrt(ret); + }); + case EFormat::QBitVector: + return QBitVectorFunc(str1, str2, [](const ui64* v1, const ui64* v2, ui64 elementCount) { + ui64 ret = 0; + const size_t fullBlocks = elementCount / 64; + const size_t remainder = elementCount % 64; + + // Process full blocks + for (size_t i = 0; i < fullBlocks; ++i) { + ret += std::popcount(v1[i] ^ v2[i]); + } + + // Process partial block with mask + if (remainder > 0) { + // Defensive: remainder is mathematically [0, 63], but protect against UB + const ui64 mask = (remainder >= 64) ? ~ui64{0} : ((ui64{1} << remainder) - 1); + ret += std::popcount((v1[fullBlocks] ^ v2[fullBlocks]) & mask); + } + return NPrivate::NL2Distance::L2DistanceSqrt(ret); }); default: @@ -211,6 +269,26 @@ class KnnDistance { BitVectorHandleOp(bitLen, v1, v2, [&](ui64 d1, ui64 d2) { ret += std::popcount(d1 & d2); }); + return ret; + }); + case EFormat::QBitVector: + return QBitVectorFunc(str1, str2, [](const ui64* v1, const ui64* v2, ui64 elementCount) { + ui64 ret = 0; + const size_t fullBlocks = elementCount / 64; + const size_t remainder = elementCount % 64; + + // Process full blocks + for (size_t i = 0; i < fullBlocks; ++i) { + ret += std::popcount(v1[i] & v2[i]); + } + + // Process partial block with mask + if (remainder > 0) { + // Defensive: remainder is mathematically [0, 63], but protect against UB + const ui64 mask = (remainder >= 64) ? ~ui64{0} : ((ui64{1} << remainder) - 1); + ret += std::popcount((v1[fullBlocks] & v2[fullBlocks]) & mask); + } + return ret; }); default: @@ -263,6 +341,32 @@ class KnnDistance { rr += std::popcount(d2); lr += std::popcount(d1 & d2); }); + return compute(ll, lr, rr); + }); + case EFormat::QBitVector: + return QBitVectorFunc(str1, str2, [&](const ui64* v1, const ui64* v2, ui64 elementCount) { + ui64 ll = 0; + ui64 rr = 0; + ui64 lr = 0; + const size_t fullBlocks = elementCount / 64; + const size_t remainder = elementCount % 64; + + // Process full blocks + for (size_t i = 0; i < fullBlocks; ++i) { + ll += std::popcount(v1[i]); + rr += std::popcount(v2[i]); + lr += std::popcount(v1[i] & v2[i]); + } + + // Process partial block with mask + if (remainder > 0) { + // Defensive: remainder is mathematically [0, 63], but protect against UB + const ui64 mask = (remainder >= 64) ? ~ui64{0} : ((ui64{1} << remainder) - 1); + ll += std::popcount(v1[fullBlocks] & mask); + rr += std::popcount(v2[fullBlocks] & mask); + lr += std::popcount((v1[fullBlocks] & v2[fullBlocks]) & mask); + } + return compute(ll, lr, rr); }); default: diff --git a/ydb/library/yql/udfs/common/knn/knn-serializer-shared.h b/ydb/library/yql/udfs/common/knn/knn-serializer-shared.h index 3d8b2732b816..ff245db93572 100644 --- a/ydb/library/yql/udfs/common/knn/knn-serializer-shared.h +++ b/ydb/library/yql/udfs/common/knn/knn-serializer-shared.h @@ -7,7 +7,9 @@ #include #include +#include #include +#include namespace NKnnVectorSerialization { @@ -212,12 +214,137 @@ namespace NKnnVectorSerialization { } }; + // QBit format: stores bits in transposed layout for better cache locality + // Elements are grouped into blocks of 64, and bits are transposed within each block + // This improves performance when computing distances between multiple vectors + struct TQBit {}; + + template <> + class TSerializer { + private: + IOutputStream* OutStream_ = nullptr; + std::vector Block_; + size_t ElementCount_ = 0; + static constexpr size_t BlockSize = 64; + + public: + TSerializer() = delete; + TSerializer(IOutputStream* outStream) + : OutStream_(outStream) + { + Block_.reserve(BlockSize); + } + ~TSerializer() { + if (OutStream_ != nullptr) { + Finish(); + } + } + + template + void HandleElement(const TFrom& from) { + Y_ENSURE(OutStream_); + + // Store the bit value (1 for positive, 0 for negative or zero) + ui64 bit = (from > 0) ? 1 : 0; + Block_.push_back(bit); + ElementCount_++; + + // When we have a complete block, write it + if (Block_.size() == BlockSize) { + WriteTransposedBlock(); + } + } + + void Finish() { + Y_ENSURE(OutStream_); + Y_DEFER { + OutStream_ = nullptr; + }; + + // Write any remaining elements in the last block + if (!Block_.empty()) { + WriteTransposedBlock(); + } + + // Write metadata: element count (4 bytes) and format (1 byte) + ui32 count = static_cast(ElementCount_); + OutStream_->Write(&count, sizeof(ui32)); + const auto format = EFormat::QBitVector; + OutStream_->Write(&format, HeaderLen); + } + + private: + void WriteTransposedBlock() { + // Pack bits into a single ui64 block + // Bit i in the ui64 corresponds to element i in the block + // This layout enables efficient SIMD operations on packed bits + ui64 transposed = 0; + for (size_t i = 0; i < Block_.size(); ++i) { + if (Block_[i]) { + transposed |= (ui64{1} << i); + } + } + OutStream_->Write(&transposed, sizeof(ui64)); + Block_.clear(); + } + }; + + template <> + class TDeserializer { + private: + const TStringBuf Data_; + ui32 ElementCount_; + + public: + TDeserializer() = delete; + TDeserializer(const TStringBuf data) + : Data_(data) + { + Y_ENSURE(data.size() >= HeaderLen + sizeof(ui32)); + Y_ENSURE(data[data.size() - HeaderLen] == EFormat::QBitVector); + + // Read element count from the metadata + std::memcpy(&ElementCount_, data.data() + data.size() - HeaderLen - sizeof(ui32), sizeof(ui32)); + } + + size_t GetElementCount() const { + return ElementCount_; + } + + void DoDeserialize(std::function&& elementHandler) const { + const size_t fullBlocks = ElementCount_ / 64; + const size_t remainder = ElementCount_ % 64; + const auto ptr = reinterpret_cast(Data_.data()); + + // Process full blocks + for (size_t blockIdx = 0; blockIdx < fullBlocks; ++blockIdx) { + ui64 block = ptr[blockIdx]; + for (size_t i = 0; i < 64; ++i) { + bool bit = (block & (ui64{1} << i)) != 0; + elementHandler(bit); + } + } + + // Process remaining elements in the last block + if (remainder > 0) { + ui64 block = ptr[fullBlocks]; + for (size_t i = 0; i < remainder; ++i) { + bool bit = (block & (ui64{1} << i)) != 0; + elementHandler(bit); + } + } + } + }; + template inline size_t GetBufferSize(const size_t elementCount) { if constexpr (std::is_same_v) { // We expect byte length of the result is (bit-length / 8 + bit-length % 8 != 0) + bits-count (1 byte) + HeaderLen // First part can be optimized to (bit-length + 7) / 8 return (elementCount + 7) / 8 + 1 + HeaderLen; + } else if constexpr (std::is_same_v) { + // QBit format: (elementCount + 63) / 64 blocks of ui64 + element count (4 bytes) + HeaderLen + return ((elementCount + 63) / 64) * sizeof(ui64) + sizeof(ui32) + HeaderLen; } else { return elementCount * sizeof(TTo) + HeaderLen; } diff --git a/ydb/library/yql/udfs/common/knn/knn-serializer.h b/ydb/library/yql/udfs/common/knn/knn-serializer.h index cdcc5ec7782a..2f97095cbe9f 100644 --- a/ydb/library/yql/udfs/common/knn/knn-serializer.h +++ b/ydb/library/yql/udfs/common/knn/knn-serializer.h @@ -75,6 +75,8 @@ class TKnnSerializerFacade { return TKnnVectorSerializer::Deserialize(valueBuilder, str); case EFormat::BitVector: return TKnnVectorSerializer::Deserialize(valueBuilder, str); + case EFormat::QBitVector: + return TKnnVectorSerializer::Deserialize(valueBuilder, str); default: return {}; } diff --git a/ydb/library/yql/udfs/common/knn/knn.cpp b/ydb/library/yql/udfs/common/knn/knn.cpp index 2810a3006c11..78b3be79d900 100644 --- a/ydb/library/yql/udfs/common/knn/knn.cpp +++ b/ydb/library/yql/udfs/common/knn/knn.cpp @@ -31,6 +31,8 @@ static constexpr const char TagUint8Vector[] = "Uint8Vector"; using TUint8Vector = TTagged; static constexpr const char TagBitVector[] = "BitVector"; using TBitVector = TTagged; +static constexpr const char TagQBitVector[] = "QBitVector"; +using TQBitVector = TTagged; SIMPLE_STRICT_UDF(TToBinaryStringFloat, TFloatVector(TAutoMap>)) { return TKnnVectorSerializer::Serialize(valueBuilder, args[0]); @@ -200,6 +202,105 @@ class TToBinaryStringBit: public TMultiSignatureBase { }; }; +template +class TToBinaryStringQBitImpl: public TMultiSignatureBase> { +public: + using TMultiSignatureBase>::TMultiSignatureBase; + + static const TStringRef& Name() { + static auto name = TStringRef::Of("ToBinaryStringQBit"); + return name; + } + + TUnboxedValue RunImpl(const IValueBuilder* valueBuilder, const TUnboxedValuePod* args) const { + return TKnnVectorSerializer::Serialize(valueBuilder, args[0]); + } +}; + +class TToBinaryStringQBit: public TMultiSignatureBase { +public: + using TMultiSignatureBase::TMultiSignatureBase; + + static const TStringRef& Name() { + static auto name = TStringRef::Of("ToBinaryStringQBit"); + return name; + } + + static bool DeclareSignature(const TStringRef& name, TType* userType, IFunctionTypeInfoBuilder& builder, bool typesOnly) { + if (Name() != name) { + return false; + } + + auto typeInfoHelper = builder.TypeInfoHelper(); + Y_ENSURE(userType); + TTupleTypeInspector tuple{*typeInfoHelper, userType}; + Y_ENSURE(tuple); + Y_ENSURE(tuple.GetElementsCount() > 0); + TTupleTypeInspector argsTuple{*typeInfoHelper, tuple.GetElementType(0)}; + Y_ENSURE(argsTuple); + if (argsTuple.GetElementsCount() != 1) { + builder.SetError("One argument is expected"); + return true; + } + + auto argType = argsTuple.GetElementType(0); + if (const auto kind = typeInfoHelper->GetTypeKind(argType); kind == ETypeKind::Null) { + argType = builder.SimpleType>(); + } + if (const TOptionalTypeInspector optional{*typeInfoHelper, argType}; optional) { + argType = optional.GetItemType(); + } + auto type = EType::None; + if (const TListTypeInspector list{*typeInfoHelper, argType}; list) { + if (const TDataTypeInspector data{*typeInfoHelper, list.GetItemType()}; data) { + if (data.GetTypeId() == TDataType::Id) { + type = EType::Double; + } else if (data.GetTypeId() == TDataType::Id) { + type = EType::Float; + } else if (data.GetTypeId() == TDataType::Id) { + type = EType::Uint8; + } else if (data.GetTypeId() == TDataType::Id) { + type = EType::Int8; + } + } + } + if (type == EType::None) { + TStringBuilder sb; + sb << "'List' is expected but got '"; + TTypePrinter(*typeInfoHelper, argsTuple.GetElementType(0)).Out(sb.Out); + sb << "'"; + builder.SetError(std::move(sb)); + return true; + } + + builder.UserType(userType); + builder.Args(1)->Add(argType).Flags(ICallablePayload::TArgumentFlags::AutoMap); + builder.Returns().IsStrict(); + + if (!typesOnly) { + if (type == EType::Double) { + builder.Implementation(new TToBinaryStringQBitImpl(builder)); + } else if (type == EType::Float) { + builder.Implementation(new TToBinaryStringQBitImpl(builder)); + } else if (type == EType::Uint8) { + builder.Implementation(new TToBinaryStringQBitImpl(builder)); + } else if (type == EType::Int8) { + builder.Implementation(new TToBinaryStringQBitImpl(builder)); + } + } + return true; + } + +private: + enum class EType { + None, + Double, + Float, + Uint8, + Int8, + }; +}; + class TFloatFromBinaryString: public TMultiSignatureBase { public: using TMultiSignatureBase::TMultiSignatureBase; @@ -232,9 +333,9 @@ class TFloatFromBinaryString: public TMultiSignatureBase auto argType = argsTuple.GetElementType(0); auto argTag = GetArg(*typeInfoHelper, argType, builder); - if (!ValidTag(argTag, {TagStoredVector, TagFloatVector, TagInt8Vector, TagUint8Vector, TagBitVector})) { + if (!ValidTag(argTag, {TagStoredVector, TagFloatVector, TagInt8Vector, TagUint8Vector, TagBitVector, TagQBitVector})) { TStringBuilder sb; - sb << "A result from 'ToBinaryString[Float|Int8|Uint8|Bit]' is expected as an argument but got '"; + sb << "A result from 'ToBinaryString[Float|Int8|Uint8|Bit|QBit]' is expected as an argument but got '"; TTypePrinter(*typeInfoHelper, argsTuple.GetElementType(0)).Out(sb.Out); sb << "'"; builder.SetError(std::move(sb)); @@ -243,7 +344,7 @@ class TFloatFromBinaryString: public TMultiSignatureBase builder.UserType(userType); builder.Args(1)->Add(argType).Flags(ICallablePayload::TArgumentFlags::AutoMap); - if (ValidTag(argTag, {TagFloatVector, TagInt8Vector, TagUint8Vector, TagBitVector}) && argType == argsTuple.GetElementType(0)) { + if (ValidTag(argTag, {TagFloatVector, TagInt8Vector, TagUint8Vector, TagBitVector, TagQBitVector}) && argType == argsTuple.GetElementType(0)) { builder.Returns>().IsStrict(); } else { builder.Returns>>().IsStrict(); @@ -285,10 +386,10 @@ class TDistanceBase: public TMultiSignatureBase { auto arg1Type = argsTuple.GetElementType(1); auto arg1Tag = Base::GetArg(*typeInfoHelper, arg1Type, builder); - if (!Base::ValidTag(arg0Tag, {TagStoredVector, TagFloatVector, TagInt8Vector, TagUint8Vector, TagBitVector}) || - !Base::ValidTag(arg1Tag, {TagStoredVector, TagFloatVector, TagInt8Vector, TagUint8Vector, TagBitVector})) { + if (!Base::ValidTag(arg0Tag, {TagStoredVector, TagFloatVector, TagInt8Vector, TagUint8Vector, TagBitVector, TagQBitVector}) || + !Base::ValidTag(arg1Tag, {TagStoredVector, TagFloatVector, TagInt8Vector, TagUint8Vector, TagBitVector, TagQBitVector})) { TStringBuilder sb; - sb << "Both arguments are expected to be results from 'ToBinaryString[Float|Int8|Uint8]' but got '"; + sb << "Both arguments are expected to be results from 'ToBinaryString[Float|Int8|Uint8|Bit|QBit]' but got '"; TTypePrinter(*typeInfoHelper, argsTuple.GetElementType(0)).Out(sb.Out); sb << "' and '"; TTypePrinter(*typeInfoHelper, argsTuple.GetElementType(1)).Out(sb.Out); @@ -415,6 +516,7 @@ SIMPLE_MODULE(TKnnModule, TToBinaryStringInt8, TToBinaryStringUint8, TToBinaryStringBit, + TToBinaryStringQBit, TFloatFromBinaryString, TInnerProductSimilarity, TCosineSimilarity, diff --git a/ydb/library/yql/udfs/common/knn/test/cases/QBitSerialization.sql b/ydb/library/yql/udfs/common/knn/test/cases/QBitSerialization.sql new file mode 100644 index 000000000000..2da8448ae396 --- /dev/null +++ b/ydb/library/yql/udfs/common/knn/test/cases/QBitSerialization.sql @@ -0,0 +1,75 @@ +-- QBit serialization test +$vector_pos = [1.0f, 2.0f, 3.0f, 4.0f]; +$vector_neg = [-1.0f, -2.0f, -3.0f, -4.0f]; +$vector_mixed = [1.0f, -2.0f, 3.0f, -4.0f]; + +$qbitvector_pos = Knn::ToBinaryStringQBit($vector_pos); +$qbitvector_neg = Knn::ToBinaryStringQBit($vector_neg); +$qbitvector_mixed = Knn::ToBinaryStringQBit($vector_mixed); + +select $qbitvector_pos; +select $qbitvector_neg; +select $qbitvector_mixed; + +-- deserialization +$deserialized_pos = Knn::FloatFromBinaryString($qbitvector_pos); +$deserialized_neg = Knn::FloatFromBinaryString($qbitvector_neg); +$deserialized_mixed = Knn::FloatFromBinaryString($qbitvector_mixed); + +select $deserialized_pos; +select $deserialized_neg; +select $deserialized_mixed; + +-- test different types +$vector_d = Cast([1, -1, 1, -1] AS List); +$vector_f = Cast([1, -1, 1, -1] AS List); +$vector_u8 = Cast([1, 0, 1, 0] AS List); +$vector_i8 = Cast([1, -1, 1, -1] AS List); + +select Knn::ToBinaryStringQBit($vector_d); +select Knn::ToBinaryStringQBit($vector_f); +select Knn::ToBinaryStringQBit($vector_u8); +select Knn::ToBinaryStringQBit($vector_i8); + +-- test with 64 elements (full block) +$vector_64_pos = ListReplicate(1.0f, 64); +$vector_64_neg = ListReplicate(-1.0f, 64); + +$qbitvector_64_pos = Knn::ToBinaryStringQBit($vector_64_pos); +$qbitvector_64_neg = Knn::ToBinaryStringQBit($vector_64_neg); + +select Len(Untag($qbitvector_64_pos, "QBitVector")); +select Len(Untag($qbitvector_64_neg, "QBitVector")); + +$deserialized_64_pos = Knn::FloatFromBinaryString($qbitvector_64_pos); +$deserialized_64_neg = Knn::FloatFromBinaryString($qbitvector_64_neg); + +select $deserialized_64_pos; +select $deserialized_64_neg; + +-- test with 128 elements (two full blocks) +$vector_128_pos = ListReplicate(1.0f, 128); +$qbitvector_128_pos = Knn::ToBinaryStringQBit($vector_128_pos); +select Len(Untag($qbitvector_128_pos, "QBitVector")); +$deserialized_128_pos = Knn::FloatFromBinaryString($qbitvector_128_pos); +select ListLength($deserialized_128_pos); + +-- test with 65 elements (one full block + 1) +$vector_65_pos = ListReplicate(1.0f, 65); +$qbitvector_65_pos = Knn::ToBinaryStringQBit($vector_65_pos); +select Len(Untag($qbitvector_65_pos, "QBitVector")); +$deserialized_65_pos = Knn::FloatFromBinaryString($qbitvector_65_pos); +select ListLength($deserialized_65_pos); + +-- Distance functions with QBit +$vec1 = [1.0f, 1.0f, 1.0f, 1.0f]; +$vec2 = [1.0f, 1.0f, -1.0f, -1.0f]; + +$qbit1 = Knn::ToBinaryStringQBit($vec1); +$qbit2 = Knn::ToBinaryStringQBit($vec2); + +select Knn::InnerProductSimilarity($qbit1, $qbit2); +select Knn::CosineSimilarity($qbit1, $qbit2); +select Knn::CosineDistance($qbit1, $qbit2); +select Knn::ManhattanDistance($qbit1, $qbit2); +select Knn::EuclideanDistance($qbit1, $qbit2);