Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ydb/library/yql/udfs/common/knn/knn-defines.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ enum EFormat: ui8 {
Uint8Vector = 2, // 1-byte per element, better than Int8 for positive-only Float
Int8Vector = 3, // 1-byte per element
BitVector = 10, // 1-bit per element
QBitVector = 11, // 1-bit per element, transposed layout for better cache locality
};

template <typename T>
Expand Down
104 changes: 104 additions & 0 deletions ydb/library/yql/udfs/common/knn/knn-distance.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <util/stream/format.h>

#include <bit>
#include <cstring>

namespace {

Expand Down Expand Up @@ -42,6 +43,16 @@ namespace {
return {reinterpret_cast<const ui64*>(buf), len};
}

// Extracts the packed-bit payload and the element count from a QBitVector-serialized string.
// Expected layout: ui64 blocks | ui32 element count | HeaderLen format bytes.
// Returns a value-initialized TBitArray when the string is too short to hold the trailer or
// when the payload is shorter than the declared element count requires, so that callers
// never index past the end of the buffer.
inline TBitArray GetQBitArray(const TStringBuf& str) {
    const size_t trailerLen = HeaderLen + sizeof(ui32);
    if (Y_UNLIKELY(str.Size() < trailerLen)) {
        return {};
    }
    const char* buf = str.Data();
    const size_t payloadLen = str.Size() - trailerLen;

    ui32 elementCount;
    std::memcpy(&elementCount, buf + payloadLen, sizeof(ui32));

    // The count is read from the serialized data itself, so it has to be checked against
    // the number of bytes that are really present (64-bit math avoids overflow on rounding up).
    const ui64 requiredLen = (ui64{elementCount} + 63) / 64 * sizeof(ui64);
    if (Y_UNLIKELY(payloadLen < requiredLen)) {
        return {};
    }
    return {reinterpret_cast<const ui64*>(buf), elementCount};
}

} // namespace

template <typename TRes = float>
Expand Down Expand Up @@ -118,6 +129,13 @@ class KnnDistance {
return VectorFuncImpl(v1, v2, bitLen1, bitLen2, std::forward<Func>(func));
}

// Decodes both QBitVector-serialized arguments with GetQBitArray and passes the resulting
// block pointers and element counts, together with `func`, to VectorFuncImpl.
template <typename Func>
static auto QBitVectorFunc(const TStringBuf& str1, const TStringBuf& str2, Func&& func) {
    auto [lhsBlocks, lhsCount] = GetQBitArray(str1);
    auto [rhsBlocks, rhsCount] = GetQBitArray(str2);
    return VectorFuncImpl(
        lhsBlocks,
        rhsBlocks,
        lhsCount,
        rhsCount,
        std::forward<Func>(func));
}

public:
static TDistanceResult ManhattanDistance(const TStringBuf& str1, const TStringBuf& str2) {
const ui8 format1 = str1.Data()[str1.Size() - HeaderLen];
Expand Down Expand Up @@ -145,6 +163,26 @@ class KnnDistance {
BitVectorHandleOp(bitLen, v1, v2, [&](ui64 d1, ui64 d2) {
ret += std::popcount(d1 ^ d2);
});
return ret;
});
case EFormat::QBitVector:
return QBitVectorFunc(str1, str2, [](const ui64* v1, const ui64* v2, ui64 elementCount) {
ui64 ret = 0;
const size_t fullBlocks = elementCount / 64;
const size_t remainder = elementCount % 64;

// Process full blocks
for (size_t i = 0; i < fullBlocks; ++i) {
ret += std::popcount(v1[i] ^ v2[i]);
}

// Process partial block with mask
if (remainder > 0) {
// Defensive: remainder is mathematically [0, 63], but protect against UB
const ui64 mask = (remainder >= 64) ? ~ui64{0} : ((ui64{1} << remainder) - 1);
ret += std::popcount((v1[fullBlocks] ^ v2[fullBlocks]) & mask);
}

return ret;
});
default:
Expand Down Expand Up @@ -178,6 +216,26 @@ class KnnDistance {
BitVectorHandleOp(bitLen, v1, v2, [&](ui64 d1, ui64 d2) {
ret += std::popcount(d1 ^ d2);
});
return NPrivate::NL2Distance::L2DistanceSqrt(ret);
});
case EFormat::QBitVector:
return QBitVectorFunc(str1, str2, [](const ui64* v1, const ui64* v2, ui64 elementCount) {
ui64 ret = 0;
const size_t fullBlocks = elementCount / 64;
const size_t remainder = elementCount % 64;

// Process full blocks
for (size_t i = 0; i < fullBlocks; ++i) {
ret += std::popcount(v1[i] ^ v2[i]);
}

// Process partial block with mask
if (remainder > 0) {
// Defensive: remainder is mathematically [0, 63], but protect against UB
const ui64 mask = (remainder >= 64) ? ~ui64{0} : ((ui64{1} << remainder) - 1);
ret += std::popcount((v1[fullBlocks] ^ v2[fullBlocks]) & mask);
}

return NPrivate::NL2Distance::L2DistanceSqrt(ret);
});
default:
Expand Down Expand Up @@ -211,6 +269,26 @@ class KnnDistance {
BitVectorHandleOp(bitLen, v1, v2, [&](ui64 d1, ui64 d2) {
ret += std::popcount(d1 & d2);
});
return ret;
});
case EFormat::QBitVector:
return QBitVectorFunc(str1, str2, [](const ui64* v1, const ui64* v2, ui64 elementCount) {
ui64 ret = 0;
const size_t fullBlocks = elementCount / 64;
const size_t remainder = elementCount % 64;

// Process full blocks
for (size_t i = 0; i < fullBlocks; ++i) {
ret += std::popcount(v1[i] & v2[i]);
}

// Process partial block with mask
if (remainder > 0) {
// Defensive: remainder is mathematically [0, 63], but protect against UB
const ui64 mask = (remainder >= 64) ? ~ui64{0} : ((ui64{1} << remainder) - 1);
ret += std::popcount((v1[fullBlocks] & v2[fullBlocks]) & mask);
}

return ret;
});
default:
Expand Down Expand Up @@ -263,6 +341,32 @@ class KnnDistance {
rr += std::popcount(d2);
lr += std::popcount(d1 & d2);
});
return compute(ll, lr, rr);
});
case EFormat::QBitVector:
return QBitVectorFunc(str1, str2, [&](const ui64* v1, const ui64* v2, ui64 elementCount) {
ui64 ll = 0;
ui64 rr = 0;
ui64 lr = 0;
const size_t fullBlocks = elementCount / 64;
const size_t remainder = elementCount % 64;

// Process full blocks
for (size_t i = 0; i < fullBlocks; ++i) {
ll += std::popcount(v1[i]);
rr += std::popcount(v2[i]);
lr += std::popcount(v1[i] & v2[i]);
}

// Process partial block with mask
if (remainder > 0) {
// Defensive: remainder is mathematically [0, 63], but protect against UB
const ui64 mask = (remainder >= 64) ? ~ui64{0} : ((ui64{1} << remainder) - 1);
ll += std::popcount(v1[fullBlocks] & mask);
rr += std::popcount(v2[fullBlocks] & mask);
lr += std::popcount((v1[fullBlocks] & v2[fullBlocks]) & mask);
}

return compute(ll, lr, rr);
});
default:
Expand Down
127 changes: 127 additions & 0 deletions ydb/library/yql/udfs/common/knn/knn-serializer-shared.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
#include <util/generic/yexception.h>
#include <util/stream/output.h>

#include <cstring>
#include <functional>
#include <vector>

namespace NKnnVectorSerialization {

Expand Down Expand Up @@ -212,12 +214,137 @@ namespace NKnnVectorSerialization {
}
};

// QBit format: one bit per element, packed into 64-bit words.
// Elements are grouped into blocks of 64; within a block, element i is stored in bit i of a
// single ui64 (1 when the source value is > 0, 0 otherwise). The blocks are followed by a
// ui32 element count and the format byte.
// NOTE(review): the format is described as a "transposed layout for better cache locality",
// but the serializer below writes plain per-block bit packing — confirm the intended layout.
// Empty tag type that selects the QBit specializations of TSerializer / TDeserializer.
struct TQBit {};

template <>
class TSerializer<TQBit> {
private:
IOutputStream* OutStream_ = nullptr;
std::vector<ui64> Block_;
size_t ElementCount_ = 0;
static constexpr size_t BlockSize = 64;

public:
TSerializer() = delete;
TSerializer(IOutputStream* outStream)
: OutStream_(outStream)
{
Block_.reserve(BlockSize);
}
~TSerializer() {
if (OutStream_ != nullptr) {
Finish();
}
}

template <typename TFrom>
void HandleElement(const TFrom& from) {
Y_ENSURE(OutStream_);

// Store the bit value (1 for positive, 0 for negative or zero)
ui64 bit = (from > 0) ? 1 : 0;
Block_.push_back(bit);
ElementCount_++;

// When we have a complete block, write it
if (Block_.size() == BlockSize) {
WriteTransposedBlock();
}
}

void Finish() {
Y_ENSURE(OutStream_);
Y_DEFER {
OutStream_ = nullptr;
};

// Write any remaining elements in the last block
if (!Block_.empty()) {
WriteTransposedBlock();
}

// Write metadata: element count (4 bytes) and format (1 byte)
ui32 count = static_cast<ui32>(ElementCount_);
OutStream_->Write(&count, sizeof(ui32));
const auto format = EFormat::QBitVector;
OutStream_->Write(&format, HeaderLen);
}

private:
void WriteTransposedBlock() {
// Pack bits into a single ui64 block
// Bit i in the ui64 corresponds to element i in the block
// This layout enables efficient SIMD operations on packed bits
ui64 transposed = 0;
for (size_t i = 0; i < Block_.size(); ++i) {
if (Block_[i]) {
transposed |= (ui64{1} << i);
}
}
OutStream_->Write(&transposed, sizeof(ui64));
Block_.clear();
}
};

template <>
class TDeserializer<TQBit> {
private:
const TStringBuf Data_;
ui32 ElementCount_;

public:
TDeserializer() = delete;
TDeserializer(const TStringBuf data)
: Data_(data)
{
Y_ENSURE(data.size() >= HeaderLen + sizeof(ui32));
Y_ENSURE(data[data.size() - HeaderLen] == EFormat::QBitVector);

// Read element count from the metadata
std::memcpy(&ElementCount_, data.data() + data.size() - HeaderLen - sizeof(ui32), sizeof(ui32));
}

size_t GetElementCount() const {
return ElementCount_;
}

void DoDeserialize(std::function<void(const bool&)>&& elementHandler) const {
const size_t fullBlocks = ElementCount_ / 64;
const size_t remainder = ElementCount_ % 64;
const auto ptr = reinterpret_cast<const ui64*>(Data_.data());

// Process full blocks
for (size_t blockIdx = 0; blockIdx < fullBlocks; ++blockIdx) {
ui64 block = ptr[blockIdx];
for (size_t i = 0; i < 64; ++i) {
bool bit = (block & (ui64{1} << i)) != 0;
elementHandler(bit);
}
}

// Process remaining elements in the last block
if (remainder > 0) {
ui64 block = ptr[fullBlocks];
for (size_t i = 0; i < remainder; ++i) {
bool bit = (block & (ui64{1} << i)) != 0;
elementHandler(bit);
}
}
}
};

template <typename TTo>
inline size_t GetBufferSize(const size_t elementCount) {
if constexpr (std::is_same_v<TTo, bool>) {
// We expect byte length of the result is (bit-length / 8 + bit-length % 8 != 0) + bits-count (1 byte) + HeaderLen
// First part can be optimized to (bit-length + 7) / 8
return (elementCount + 7) / 8 + 1 + HeaderLen;
} else if constexpr (std::is_same_v<TTo, TQBit>) {
// QBit format: (elementCount + 63) / 64 blocks of ui64 + element count (4 bytes) + HeaderLen
return ((elementCount + 63) / 64) * sizeof(ui64) + sizeof(ui32) + HeaderLen;
} else {
return elementCount * sizeof(TTo) + HeaderLen;
}
Expand Down
2 changes: 2 additions & 0 deletions ydb/library/yql/udfs/common/knn/knn-serializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ class TKnnSerializerFacade {
return TKnnVectorSerializer<ui8, float>::Deserialize(valueBuilder, str);
case EFormat::BitVector:
return TKnnVectorSerializer<bool, float>::Deserialize(valueBuilder, str);
case EFormat::QBitVector:
return TKnnVectorSerializer<NKnnVectorSerialization::TQBit, float>::Deserialize(valueBuilder, str);
default:
return {};
}
Expand Down
Loading