Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 70 additions & 7 deletions ydb/core/kqp/ut/indexes/kqp_indexes_prefixed_vector_ut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ Y_UNIT_TEST_SUITE(KqpPrefixedVectorIndexes) {
DoPositiveQueriesPrefixedVectorIndexOrderBy(session, "CosineSimilarity", "DESC", covered);
}

TSession DoCreateTableForPrefixedVectorIndex(TTableClient& db, bool nullable) {
TSession DoCreateTableForPrefixedVectorIndex(TTableClient& db, bool nullable, bool suffixPk = false) {
auto session = db.CreateSession().GetValueSync().GetSession();

{
Expand All @@ -191,14 +191,25 @@ Y_UNIT_TEST_SUITE(KqpPrefixedVectorIndexes) {
.AddNonNullableColumn("emb", EPrimitiveType::String)
.AddNonNullableColumn("data", EPrimitiveType::String);
}
tableBuilder.SetPrimaryKeyColumns({"pk"});
if (suffixPk) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

тест что ты добавил окей пусть будет, но может ещё простеньких тестов сделать несколько штук как в тикете?

#18196

то есть CREATE TABLE + индекс и проверить содержимое всех трех таблиц и их схему

(и не для префиксного тоже бы)

без сплитов, просто базовую функциональность

их раньше ещё читать нельзя было, а теперь можно

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

короче говоря посмотрел, вроде дело хорошее, но выглядит как будто надо это отдельно делать, тут-то именно багфикс. а остальные кейсы, где не префиксный и т.п. - вроде и так другими тестами проверяются

tableBuilder.SetPrimaryKeyColumns({"pk", "user"});
} else {
tableBuilder.SetPrimaryKeyColumns({"pk"});
}
tableBuilder.BeginPartitioningSettings()
.SetMinPartitionsCount(3)
.EndPartitioningSettings();
auto partitions = TExplicitPartitions{}
.AppendSplitPoints(TValueBuilder{}.BeginTuple().AddElement().OptionalInt64(40).EndTuple().Build())
.AppendSplitPoints(TValueBuilder{}.BeginTuple().AddElement().OptionalInt64(60).EndTuple().Build());
tableBuilder.SetPartitionAtKeys(partitions);
.EndPartitioningSettings();
if (suffixPk) {
auto partitions = TExplicitPartitions{}
.AppendSplitPoints(TValueBuilder{}.BeginTuple().AddElement().OptionalInt64(40).AddElement().OptionalString("").EndTuple().Build())
.AppendSplitPoints(TValueBuilder{}.BeginTuple().AddElement().OptionalInt64(60).AddElement().OptionalString("").EndTuple().Build());
tableBuilder.SetPartitionAtKeys(partitions);
} else {
auto partitions = TExplicitPartitions{}
.AppendSplitPoints(TValueBuilder{}.BeginTuple().AddElement().OptionalInt64(40).EndTuple().Build())
.AppendSplitPoints(TValueBuilder{}.BeginTuple().AddElement().OptionalInt64(60).EndTuple().Build());
tableBuilder.SetPartitionAtKeys(partitions);
}
auto result = session.CreateTable("/Root/TestTable", tableBuilder.Build()).ExtractValueSync();
UNIT_ASSERT_VALUES_EQUAL(result.IsTransportError(), false);
UNIT_ASSERT_VALUES_EQUAL_C(result.GetStatus(), EStatus::SUCCESS, result.GetIssues().ToString());
Expand Down Expand Up @@ -488,6 +499,58 @@ Y_UNIT_TEST_SUITE(KqpPrefixedVectorIndexes) {
DoPositiveQueriesPrefixedVectorIndexOrderByCosine(session, true /*covered*/);
}

// Builds a prefixed vector index (prefix column `user`, embedding column `emb`) on a table
// created with suffixPk = true, i.e. with primary key (pk, user), so the prefix column is
// also part of the primary key. Checks the resulting index description and then runs the
// shared cosine ORDER BY query checks. Parameterized by Nullable and Covered.
Y_UNIT_TEST_QUAD(CosineDistanceWithPkPrefix, Nullable, Covered) {
    NKikimrConfig::TFeatureFlags featureFlags;
    featureFlags.SetEnableVectorIndex(true);
    auto setting = NKikimrKqp::TKqpSetting();
    auto serverSettings = TKikimrSettings()
        .SetFeatureFlags(featureFlags)
        .SetKqpSettings({setting});

    TKikimrRunner kikimr(serverSettings);
    // Trace-level logs for the index build and the schemeshard.
    kikimr.GetTestServer().GetRuntime()->SetLogPriority(NKikimrServices::BUILD_INDEX, NActors::NLog::PRI_TRACE);
    kikimr.GetTestServer().GetRuntime()->SetLogPriority(NKikimrServices::FLAT_TX_SCHEMESHARD, NActors::NLog::PRI_TRACE);

    auto db = kikimr.GetTableClient();

    auto session = DoCreateTableForPrefixedVectorIndex(db, Nullable, true /*suffixPk*/);
    {
        // The raw string is kept flush-left so the query text stays byte-identical.
        const TString createIndex(Q_(Sprintf(R"(
ALTER TABLE `/Root/TestTable`
ADD INDEX index
GLOBAL USING vector_kmeans_tree
ON (user, emb) %s
WITH (distance=cosine, vector_type="uint8", vector_dimension=2, levels=2, clusters=2);
)", (Covered ? "COVER (emb, data)" : ""))));

        auto result = session.ExecuteSchemeQuery(createIndex)
            .ExtractValueSync();

        UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString());
    }
    {
        auto result = session.DescribeTable("/Root/TestTable").ExtractValueSync();
        // Report the issues on failure, like the other status checks in this suite.
        UNIT_ASSERT_VALUES_EQUAL_C(result.GetStatus(), NYdb::EStatus::SUCCESS, result.GetIssues().ToString());
        const auto& indexes = result.GetTableDescription().GetIndexDescriptions();
        UNIT_ASSERT_EQUAL(indexes.size(), 1);
        UNIT_ASSERT_EQUAL(indexes[0].GetIndexName(), "index");
        std::vector<std::string> indexKeyColumns{"user", "emb"};
        UNIT_ASSERT_EQUAL(indexes[0].GetIndexColumns(), indexKeyColumns);
        // Data columns are expected only when the index was created with COVER.
        std::vector<std::string> indexDataColumns;
        if (Covered) {
            indexDataColumns = {"emb", "data"};
        }
        UNIT_ASSERT_EQUAL(indexes[0].GetDataColumns(), indexDataColumns);
        const auto& settings = std::get<TKMeansTreeSettings>(indexes[0].GetIndexSettings());
        UNIT_ASSERT_EQUAL(settings.Settings.Metric, NYdb::NTable::TVectorIndexSettings::EMetric::CosineDistance);
        UNIT_ASSERT_EQUAL(settings.Settings.VectorType, NYdb::NTable::TVectorIndexSettings::EVectorType::Uint8);
        UNIT_ASSERT_EQUAL(settings.Settings.VectorDimension, 2);
        UNIT_ASSERT_EQUAL(settings.Levels, 2);
        UNIT_ASSERT_EQUAL(settings.Clusters, 2);
    }
    DoPositiveQueriesPrefixedVectorIndexOrderByCosine(session, Covered);
}

}

}
Expand Down
2 changes: 2 additions & 0 deletions ydb/core/protos/tx_datashard.proto
Original file line number Diff line number Diff line change
Expand Up @@ -1668,6 +1668,8 @@ message TEvPrefixKMeansRequest {
optional uint32 PrefixColumns = 17;

optional NKikimrIndexBuilder.TIndexBuildScanSettings ScanSettings = 18;

repeated string SourcePrimaryKeyColumns = 19;
}

message TEvPrefixKMeansResponse {
Expand Down
64 changes: 20 additions & 44 deletions ydb/core/tx/datashard/build_index/kmeans_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,51 +45,22 @@ void AddRowToLevel(TBufferData& buffer, TClusterId parent, TClusterId child, con
buffer.AddRow(TSerializedCellVec{pk}, TSerializedCellVec::Serialize(data));
}

// Buffers one row read from the main table for upload to a build table.
//
// Buffered row layout:
//   key   = (parent, key...)
//   value = serialized `row` (all scanned cells)
// `key` is additionally passed unchanged as the third AddRow argument
// (presumably the original source key; confirm against TBufferData::AddRow).
//
// `parent` must not carry the posting flag; EnsureNoPostingParentFlag checks it.
void AddRowMainToBuild(TBufferData& buffer, TClusterId parent, TArrayRef<const TCell> key, TArrayRef<const TCell> row) {
    EnsureNoPostingParentFlag(parent);

    // Start the target key with the cluster id, then append the source key cells.
    std::array<TCell, 1> parentCell{TCell::Make(parent)};
    auto targetPk = TSerializedCellVec::Serialize(parentCell);
    TSerializedCellVec::UnsafeAppendCells(key, targetPk);

    auto payload = TSerializedCellVec::Serialize(row);
    buffer.AddRow(TSerializedCellVec{std::move(targetPk)}, std::move(payload), TSerializedCellVec{key});
}

// Buffers one row read from the main table for upload to the posting table.
//
// Buffered row layout:
//   key   = (parent with the posting flag set, key...)
//   value = serialized cells of `row` starting at index `dataPos`
// `key` is additionally passed unchanged as the third AddRow argument
// (presumably the original source key; confirm against TBufferData::AddRow).
void AddRowMainToPosting(TBufferData& buffer, TClusterId parent, TArrayRef<const TCell> key, TArrayRef<const TCell> row, ui32 dataPos)
{
    const TClusterId postingParent = SetPostingParentFlag(parent);

    // Start the target key with the flagged cluster id, then append the source key cells.
    std::array<TCell, 1> parentCell{TCell::Make(postingParent)};
    auto targetPk = TSerializedCellVec::Serialize(parentCell);
    TSerializedCellVec::UnsafeAppendCells(key, targetPk);

    // Only the cells from dataPos onwards are stored as the value.
    auto payload = TSerializedCellVec::Serialize(row.Slice(dataPos));
    buffer.AddRow(TSerializedCellVec{std::move(targetPk)}, std::move(payload), TSerializedCellVec{key});
}

void AddRowBuildToBuild(TBufferData& buffer, TClusterId parent, TArrayRef<const TCell> key, TArrayRef<const TCell> row, ui32 prefixColumns)
{
EnsureNoPostingParentFlag(parent);
void AddRowToData(TBufferData& buffer, TClusterId parent, TArrayRef<const TCell> sourcePk,
TArrayRef<const TCell> dataColumns, TArrayRef<const TCell> origKey, bool isPostingLevel) {
if (isPostingLevel) {
parent = SetPostingParentFlag(parent);
} else {
EnsureNoPostingParentFlag(parent);
}

std::array<TCell, 1> cells;
cells[0] = TCell::Make(parent);
auto pk = TSerializedCellVec::Serialize(cells);
TSerializedCellVec::UnsafeAppendCells(key.Slice(prefixColumns), pk);
buffer.AddRow(TSerializedCellVec{std::move(pk)}, TSerializedCellVec::Serialize(row),
TSerializedCellVec{key});
}

void AddRowBuildToPosting(TBufferData& buffer, TClusterId parent, TArrayRef<const TCell> key, TArrayRef<const TCell> row, ui32 dataPos, ui32 prefixColumns)
{
parent = SetPostingParentFlag(parent);
TSerializedCellVec::UnsafeAppendCells(sourcePk, pk);

std::array<TCell, 1> cells;
cells[0] = TCell::Make(parent);
auto pk = TSerializedCellVec::Serialize(cells);
TSerializedCellVec::UnsafeAppendCells(key.Slice(prefixColumns), pk);
buffer.AddRow(TSerializedCellVec{std::move(pk)}, TSerializedCellVec::Serialize(row.Slice(dataPos)),
TSerializedCellVec{key});
buffer.AddRow(TSerializedCellVec{std::move(pk)},
TSerializedCellVec::Serialize(dataColumns),
TSerializedCellVec{origKey});
}

TTags MakeScanTags(const TUserTable& table, const TProtoStringType& embedding,
Expand All @@ -114,12 +85,11 @@ TTags MakeScanTags(const TUserTable& table, const TProtoStringType& embedding,

std::shared_ptr<NTxProxy::TUploadTypes> MakeOutputTypes(const TUserTable& table, NKikimrTxDataShard::EKMeansState uploadState,
const TProtoStringType& embedding, const google::protobuf::RepeatedPtrField<TProtoStringType>& data,
ui32 prefixColumns)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

тут наверно тоже два метода бы, но не уверен, посмотри сам как они логически вызываются

и рядом со всеми методами нарисуй плиз схему строки

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

да вот тут мне кажется можно и оставить, тут относительно простая логика - ну типа если даны явно колонки первичного ключа то берём их из pkColumns...

const google::protobuf::RepeatedPtrField<TProtoStringType>& pkColumns)
{
auto types = GetAllTypes(table);

auto result = std::make_shared<NTxProxy::TUploadTypes>();
result->reserve(1 + 1 + std::min((table.KeyColumnTypes.size() - prefixColumns) + data.size(), types.size()));

Ydb::Type type;
type.set_type_id(NTableIndex::ClusterIdType);
Expand All @@ -133,8 +103,14 @@ std::shared_ptr<NTxProxy::TUploadTypes> MakeOutputTypes(const TUserTable& table,
types.erase(it);
}
};
for (const auto& column : table.KeyColumnIds | std::views::drop(prefixColumns)) {
addType(table.Columns.at(column).Name);
if (pkColumns.size()) {
for (const auto& column : pkColumns) {
addType(column);
}
} else {
for (const auto& column : table.KeyColumnIds) {
addType(table.Columns.at(column).Name);
}
}
switch (uploadState) {
case NKikimrTxDataShard::EKMeansState::UPLOAD_MAIN_TO_BUILD:
Expand Down
13 changes: 4 additions & 9 deletions ydb/core/tx/datashard/build_index/kmeans_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,21 +140,16 @@ struct TMaxInnerProductSimilarity : TMetric<TCoord> {

void AddRowToLevel(TBufferData& buffer, TClusterId parent, TClusterId child, const TString& embedding, bool isPostingLevel);

void AddRowMainToBuild(TBufferData& buffer, TClusterId parent, TArrayRef<const TCell> key, TArrayRef<const TCell> row);
void AddRowToData(TBufferData& buffer, TClusterId parent, TArrayRef<const TCell> sourcePk,
TArrayRef<const TCell> dataColumns, TArrayRef<const TCell> origKey, bool isPostingLevel);

void AddRowMainToPosting(TBufferData& buffer, TClusterId parent, TArrayRef<const TCell> key, TArrayRef<const TCell> row, ui32 dataPos);

void AddRowBuildToBuild(TBufferData& buffer, TClusterId parent, TArrayRef<const TCell> key, TArrayRef<const TCell> row, ui32 prefixColumns = 1);

void AddRowBuildToPosting(TBufferData& buffer, TClusterId parent, TArrayRef<const TCell> key, TArrayRef<const TCell> row, ui32 dataPos, ui32 prefixColumns = 1);

TTags MakeScanTags(const TUserTable& table, const TProtoStringType& embedding,
TTags MakeScanTags(const TUserTable& table, const TProtoStringType& embedding,
const google::protobuf::RepeatedPtrField<TProtoStringType>& data, ui32& embeddingPos,
ui32& dataPos, NTable::TTag& embeddingTag);

std::shared_ptr<NTxProxy::TUploadTypes> MakeOutputTypes(const TUserTable& table, NKikimrTxDataShard::EKMeansState uploadState,
const TProtoStringType& embedding, const google::protobuf::RepeatedPtrField<TProtoStringType>& data,
ui32 prefixColumns = 0);
const google::protobuf::RepeatedPtrField<TProtoStringType>& pkColumns = {});

void MakeScan(auto& record, const auto& createScan, const auto& badRequest)
{
Expand Down
8 changes: 4 additions & 4 deletions ydb/core/tx/datashard/build_index/local_kmeans.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -462,28 +462,28 @@ class TLocalKMeansScan final : public TLocalKMeansScanBase {
// Main table -> build table: finds the cluster for the row and, if one is found,
// buffers the row under parent id `Child + *pos`. The whole key is the source PK and
// the whole row is the payload. The row is buffered exactly once, via AddRowToData.
void FeedMainToBuild(TArrayRef<const TCell> key, TArrayRef<const TCell> row)
{
    if (auto pos = Clusters.FindCluster(row, EmbeddingPos); pos) {
        AddRowToData(*OutputBuf, Child + *pos, key, row, key, false /*isPostingLevel*/);
    }
}

// Main table -> posting table: finds the cluster for the row and, if one is found,
// buffers it under parent id `Child + *pos` at the posting level. The whole key is the
// source PK; the payload is the row cells from DataPos onwards. The row is buffered
// exactly once, via AddRowToData.
void FeedMainToPosting(TArrayRef<const TCell> key, TArrayRef<const TCell> row)
{
    if (auto pos = Clusters.FindCluster(row, EmbeddingPos); pos) {
        AddRowToData(*OutputBuf, Child + *pos, key, row.Slice(DataPos), key, true /*isPostingLevel*/);
    }
}

// Build table -> build table: finds the cluster for the row and, if one is found,
// buffers it under parent id `Child + *pos`. The source PK is the key without its first
// cell (presumably the previous parent id; confirm against the build table schema); the
// whole row is the payload. The row is buffered exactly once, via AddRowToData.
void FeedBuildToBuild(TArrayRef<const TCell> key, TArrayRef<const TCell> row)
{
    if (auto pos = Clusters.FindCluster(row, EmbeddingPos); pos) {
        AddRowToData(*OutputBuf, Child + *pos, key.Slice(1), row, key, false /*isPostingLevel*/);
    }
}

// Build table -> posting table: finds the cluster for the row and, if one is found,
// buffers it under parent id `Child + *pos` at the posting level. The source PK is the
// key without its first cell (presumably the previous parent id; confirm against the
// build table schema); the payload is the row cells from DataPos onwards. The row is
// buffered exactly once, via AddRowToData.
void FeedBuildToPosting(TArrayRef<const TCell> key, TArrayRef<const TCell> row)
{
    if (auto pos = Clusters.FindCluster(row, EmbeddingPos); pos) {
        AddRowToData(*OutputBuf, Child + *pos, key.Slice(1), row.Slice(DataPos), key, true /*isPostingLevel*/);
    }
}

Expand Down
31 changes: 25 additions & 6 deletions ydb/core/tx/datashard/build_index/prefix_kmeans.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ using namespace NKMeans;
*
* Request:
* - The client sends TEvPrefixKMeansRequest with:
* - Child: base ID from which new cluster IDs are assigned within this request.
* - Child: base ID from which new cluster IDs are assigned within this request.
* - Each prefix group processed will be assigned cluster IDs starting at Child + 1.
* - For a request with K clusters per prefix, the IDs used for the first prefix group are
* (Child + 1) to (Child + K), and the parent ID for these is Child.
Expand Down Expand Up @@ -93,6 +93,9 @@ class TPrefixKMeansScanBase: public TActor<TPrefixKMeansScanBase>, public IActor

// FIXME: save PrefixRows as std::vector<std::pair<TSerializedCellVec, TSerializedCellVec>> to avoid parsing
const ui32 PrefixColumns;
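// Number of data columns requested by the index build (request.GetDataColumns().size()).
// For PrefixKMeans the source table's primary key columns are passed separately
// (SourcePrimaryKeyColumns) and are scanned right after the data columns, because the
// prefix table stores them in a different order when they belong both to the PK and to the prefix.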
const ui32 DataColumnCount;
TSerializedCellVec Prefix;
TBufferData PrefixRows;
bool IsFirstPrefixFeed = true;
Expand Down Expand Up @@ -126,10 +129,14 @@ class TPrefixKMeansScanBase: public TActor<TPrefixKMeansScanBase>, public IActor
, ResponseActorId{responseActorId}
, Response{std::move(response)}
, PrefixColumns{request.GetPrefixColumns()}
, DataColumnCount{(ui32)request.GetDataColumns().size()}
{
const auto& embedding = request.GetEmbeddingColumn();
const auto& data = request.GetDataColumns();
ScanTags = MakeScanTags(table, embedding, data, EmbeddingPos, DataPos, EmbeddingTag);
TVector<TString> data{request.GetDataColumns().begin(), request.GetDataColumns().end()};
for (auto & col: request.GetSourcePrimaryKeyColumns()) {
data.push_back(col);
}
ScanTags = MakeScanTags(table, embedding, {data.begin(), data.end()}, EmbeddingPos, DataPos, EmbeddingTag);
Lead.To(ScanTags, {}, NTable::ESeek::Lower);
{
Ydb::Type type;
Expand All @@ -141,7 +148,11 @@ class TPrefixKMeansScanBase: public TActor<TPrefixKMeansScanBase>, public IActor
(*levelTypes)[2] = {NTableIndex::NTableVectorKmeansTreeIndex::CentroidColumn, type};
LevelBuf = Uploader.AddDestination(request.GetLevelName(), std::move(levelTypes));
}
OutputBuf = Uploader.AddDestination(request.GetOutputName(), MakeOutputTypes(table, UploadState, embedding, data, PrefixColumns));
{
auto outputTypes = MakeOutputTypes(table, UploadState, embedding,
{data.begin(), data.begin()+request.GetDataColumns().size()}, request.GetSourcePrimaryKeyColumns());
OutputBuf = Uploader.AddDestination(request.GetOutputName(), outputTypes);
}
{
auto types = GetAllTypes(table);

Expand Down Expand Up @@ -480,14 +491,14 @@ class TPrefixKMeansScan final : public TPrefixKMeansScanBase {
// Prefix table -> build table: finds the cluster for the row and, if one is found,
// buffers it under parent id `Child + *pos`.
//
// The scanned row holds DataColumnCount data columns starting at DataPos, followed by
// the source table's primary key columns. The source PK is taken from that tail of the
// row, not from the scan key. The payload is everything before the tail. The row is
// buffered exactly once, via AddRowToData.
void FeedBuildToBuild(TArrayRef<const TCell> key, TArrayRef<const TCell> row)
{
    if (auto pos = Clusters.FindCluster(row, EmbeddingPos); pos) {
        // Index of the first source primary key cell in the scanned row.
        const auto sourcePkPos = DataPos + DataColumnCount;
        AddRowToData(*OutputBuf, Child + *pos, row.Slice(sourcePkPos), row.Slice(0, sourcePkPos), key, false /*isPostingLevel*/);
    }
}

// Prefix table -> posting table: finds the cluster for the row and, if one is found,
// buffers it under parent id `Child + *pos` at the posting level.
//
// The scanned row holds DataColumnCount data columns starting at DataPos, followed by
// the source table's primary key columns. The source PK is taken from that tail of the
// row, not from the scan key. The payload is the data columns only. The row is buffered
// exactly once, via AddRowToData.
void FeedBuildToPosting(TArrayRef<const TCell> key, TArrayRef<const TCell> row)
{
    if (auto pos = Clusters.FindCluster(row, EmbeddingPos); pos) {
        // Index of the first source primary key cell in the scanned row.
        const auto sourcePkPos = DataPos + DataColumnCount;
        AddRowToData(*OutputBuf, Child + *pos, row.Slice(sourcePkPos), row.Slice(DataPos, DataColumnCount), key, true /*isPostingLevel*/);
    }
}

Expand Down Expand Up @@ -624,6 +635,14 @@ void TDataShard::HandleSafe(TEvDataShard::TEvPrefixKMeansRequest::TPtr& ev, cons
if (request.GetPrefixColumns() > userTable.KeyColumnIds.size()) {
badRequest(TStringBuilder() << "Should not be requested on more than " << userTable.KeyColumnIds.size() << " prefix columns");
}
if (request.GetSourcePrimaryKeyColumns().size() == 0) {
badRequest("Request should include source primary key columns");
}
for (auto pkColumn : request.GetSourcePrimaryKeyColumns()) {
if (!tags.contains(pkColumn)) {
badRequest(TStringBuilder() << "Unknown source primary key column: " << pkColumn);
}
}

if (trySendBadRequest()) {
return;
Expand Down
8 changes: 4 additions & 4 deletions ydb/core/tx/datashard/build_index/reshuffle_kmeans.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -308,28 +308,28 @@ class TReshuffleKMeansScan final : public TReshuffleKMeansScanBase {
// Main table -> build table: finds the cluster for the row and, if one is found,
// buffers the row under parent id `Child + *pos`. The whole key is the source PK and
// the whole row is the payload. The row is buffered exactly once, via AddRowToData.
void FeedMainToBuild(TArrayRef<const TCell> key, TArrayRef<const TCell> row)
{
    if (auto pos = Clusters.FindCluster(row, EmbeddingPos); pos) {
        AddRowToData(*OutputBuf, Child + *pos, key, row, key, false /*isPostingLevel*/);
    }
}

// Main table -> posting table: finds the cluster for the row and, if one is found,
// buffers it under parent id `Child + *pos` at the posting level. The whole key is the
// source PK; the payload is the row cells from DataPos onwards. The row is buffered
// exactly once, via AddRowToData.
void FeedMainToPosting(TArrayRef<const TCell> key, TArrayRef<const TCell> row)
{
    if (auto pos = Clusters.FindCluster(row, EmbeddingPos); pos) {
        AddRowToData(*OutputBuf, Child + *pos, key, row.Slice(DataPos), key, true /*isPostingLevel*/);
    }
}

// Build table -> build table: finds the cluster for the row and, if one is found,
// buffers it under parent id `Child + *pos`. The source PK is the key without its first
// cell (presumably the previous parent id; confirm against the build table schema); the
// whole row is the payload. The row is buffered exactly once, via AddRowToData.
void FeedBuildToBuild(TArrayRef<const TCell> key, TArrayRef<const TCell> row)
{
    if (auto pos = Clusters.FindCluster(row, EmbeddingPos); pos) {
        AddRowToData(*OutputBuf, Child + *pos, key.Slice(1), row, key, false /*isPostingLevel*/);
    }
}

// Build table -> posting table: finds the cluster for the row and, if one is found,
// buffers it under parent id `Child + *pos` at the posting level. The source PK is the
// key without its first cell (presumably the previous parent id; confirm against the
// build table schema); the payload is the row cells from DataPos onwards. The row is
// buffered exactly once, via AddRowToData.
void FeedBuildToPosting(TArrayRef<const TCell> key, TArrayRef<const TCell> row)
{
    if (auto pos = Clusters.FindCluster(row, EmbeddingPos); pos) {
        AddRowToData(*OutputBuf, Child + *pos, key.Slice(1), row.Slice(DataPos), key, true /*isPostingLevel*/);
    }
}
};
Expand Down
Loading
Loading