From 85681169bc82e840b10bbd22dcdc0465c551bf24 Mon Sep 17 00:00:00 2001 From: Wenxuan Date: Mon, 8 Jul 2024 14:44:20 +0800 Subject: [PATCH] storage: Optimize Vector Index (#232) Signed-off-by: Wish Co-authored-by: Lloyd-Pottiger <60744015+Lloyd-Pottiger@users.noreply.github.com> --- .gitmodules | 3 + contrib/CMakeLists.txt | 4 + contrib/hdf5-cmake/.gitignore | 1 + contrib/hdf5-cmake/CMakeLists.txt | 41 ++++ contrib/highfive | 1 + contrib/highfive-cmake/CMakeLists.txt | 18 ++ dbms/CMakeLists.txt | 13 +- .../BitmapFilter/BitmapFilterView.h | 10 +- .../DeltaMergeStore_InternalSegment.cpp | 51 +++-- .../DeltaMerge/File/DMFileIndexWriter.cpp | 17 +- .../DeltaMerge/File/DMFileIndexWriter.h | 13 +- .../Storages/DeltaMerge/Index/VectorIndex.h | 10 +- .../Index/VectorIndexHNSW/Index.cpp | 47 ++++- .../DeltaMerge/Index/VectorIndexHNSW/Index.h | 8 +- .../DeltaMerge/tests/bench_dataset/.gitignore | 1 + .../DeltaMerge/tests/bench_dataset/README.md | 7 + .../DeltaMerge/tests/bench_vector_index.cpp | 98 ++++++++++ .../tests/bench_vector_index_utils.h | 178 ++++++++++++++++++ .../tests/gtest_dm_vector_index.cpp | 26 +-- 19 files changed, 506 insertions(+), 41 deletions(-) create mode 100644 contrib/hdf5-cmake/.gitignore create mode 100644 contrib/hdf5-cmake/CMakeLists.txt create mode 160000 contrib/highfive create mode 100644 contrib/highfive-cmake/CMakeLists.txt create mode 100644 dbms/src/Storages/DeltaMerge/tests/bench_dataset/.gitignore create mode 100644 dbms/src/Storages/DeltaMerge/tests/bench_dataset/README.md create mode 100644 dbms/src/Storages/DeltaMerge/tests/bench_vector_index.cpp create mode 100644 dbms/src/Storages/DeltaMerge/tests/bench_vector_index_utils.h diff --git a/.gitmodules b/.gitmodules index bf4cfbb78b5..5f5aa5e778e 100644 --- a/.gitmodules +++ b/.gitmodules @@ -152,3 +152,6 @@ [submodule "contrib/simsimd"] path = contrib/simsimd url = https://github.com/ashvardanian/SimSIMD +[submodule "contrib/highfive"] + path = contrib/highfive + url = https://github.com/BlueBrain/HighFive diff --git a/contrib/CMakeLists.txt b/contrib/CMakeLists.txt index d4ad3cc2615..61e893a2577 100644 --- a/contrib/CMakeLists.txt +++ b/contrib/CMakeLists.txt @@ -202,3 +202,7 @@ add_subdirectory(fastpforlib) add_subdirectory(usearch-cmake) add_subdirectory(simsimd-cmake) + +add_subdirectory(hdf5-cmake) + +add_subdirectory(highfive-cmake) diff --git a/contrib/hdf5-cmake/.gitignore b/contrib/hdf5-cmake/.gitignore new file mode 100644 index 00000000000..b52e5847f5f --- /dev/null +++ b/contrib/hdf5-cmake/.gitignore @@ -0,0 +1 @@ +/download/* diff --git a/contrib/hdf5-cmake/CMakeLists.txt b/contrib/hdf5-cmake/CMakeLists.txt new file mode 100644 index 00000000000..0f40c7b4f52 --- /dev/null +++ b/contrib/hdf5-cmake/CMakeLists.txt @@ -0,0 +1,41 @@ +include(ExternalProject) + +# hdf5 is too large. Instead of adding as a submodule, let's simply download from GitHub. +ExternalProject_Add(hdf5-external + PREFIX ${CMAKE_CURRENT_BINARY_DIR} + DOWNLOAD_DIR ${TiFlash_SOURCE_DIR}/contrib/hdf5-cmake/download + URL https://github.com/HDFGroup/hdf5/archive/refs/tags/hdf5_1.14.4.3.zip + URL_HASH MD5=bc987d22e787290127aacd7b99b4f31e + CMAKE_ARGS + -DCMAKE_BUILD_TYPE=Release + -DCMAKE_INSTALL_PREFIX= + -DBUILD_STATIC_LIBS=ON + -DBUILD_SHARED_LIBS=OFF + -DBUILD_TESTING=OFF + -DHDF5_BUILD_HL_LIB=OFF + -DHDF5_BUILD_TOOLS=OFF + -DHDF5_BUILD_CPP_LIB=ON + -DHDF5_BUILD_EXAMPLES=OFF + -DHDF5_ENABLE_Z_LIB_SUPPORT=OFF + -DHDF5_ENABLE_SZIP_SUPPORT=OFF + BUILD_BYPRODUCTS /lib/${CMAKE_FIND_LIBRARY_PREFIXES}hdf5.a # Workaround for Ninja + USES_TERMINAL_DOWNLOAD TRUE + USES_TERMINAL_CONFIGURE TRUE + USES_TERMINAL_BUILD TRUE + USES_TERMINAL_INSTALL TRUE + EXCLUDE_FROM_ALL TRUE + DOWNLOAD_EXTRACT_TIMESTAMP TRUE +) + +ExternalProject_Get_Property(hdf5-external INSTALL_DIR) + +add_library(tiflash_contrib::hdf5 STATIC IMPORTED GLOBAL) +set_target_properties(tiflash_contrib::hdf5 PROPERTIES + IMPORTED_LOCATION ${INSTALL_DIR}/lib/${CMAKE_FIND_LIBRARY_PREFIXES}hdf5.a +) +add_dependencies(tiflash_contrib::hdf5 hdf5-external) + +file(MAKE_DIRECTORY ${INSTALL_DIR}/include) +target_include_directories(tiflash_contrib::hdf5 SYSTEM INTERFACE + ${INSTALL_DIR}/include +) diff --git a/contrib/highfive b/contrib/highfive new file mode 160000 index 00000000000..0d0259e823a --- /dev/null +++ b/contrib/highfive @@ -0,0 +1 @@ +Subproject commit 0d0259e823a0e8aee2f036ba738c703ac4a0721c diff --git a/contrib/highfive-cmake/CMakeLists.txt b/contrib/highfive-cmake/CMakeLists.txt new file mode 100644 index 00000000000..59ca95a64ca --- /dev/null +++ b/contrib/highfive-cmake/CMakeLists.txt @@ -0,0 +1,18 @@ +set(HIGHFIVE_PROJECT_DIR "${TiFlash_SOURCE_DIR}/contrib/highfive") +set(HIGHFIVE_SOURCE_DIR "${HIGHFIVE_PROJECT_DIR}/include") + +if (NOT EXISTS "${HIGHFIVE_SOURCE_DIR}/highfive/highfive.hpp") + message (FATAL_ERROR "submodule contrib/highfive not found") +endif() + +add_library(_highfive INTERFACE) + +target_include_directories(_highfive SYSTEM INTERFACE + ${HIGHFIVE_SOURCE_DIR} +) + +target_link_libraries(_highfive INTERFACE + tiflash_contrib::hdf5 +) + +add_library(tiflash_contrib::highfive ALIAS _highfive) diff --git a/dbms/CMakeLists.txt b/dbms/CMakeLists.txt index ab99ca84159..325a0573a6b 100644 --- a/dbms/CMakeLists.txt +++ b/dbms/CMakeLists.txt @@ -388,7 +388,18 @@ if (ENABLE_TESTS) ) target_include_directories(bench_dbms BEFORE PRIVATE ${SPARCEHASH_INCLUDE_DIR} ${benchmark_SOURCE_DIR}/include) target_compile_definitions(bench_dbms PUBLIC DBMS_PUBLIC_GTEST) - target_link_libraries(bench_dbms gtest dbms test_util_bench_main benchmark tiflash_functions server_for_test delta_merge kvstore tiflash_aggregate_functions) + target_link_libraries(bench_dbms + gtest + benchmark + tiflash_contrib::highfive + + dbms + test_util_bench_main + tiflash_functions + server_for_test + delta_merge + tiflash_aggregate_functions + kvstore) add_check(bench_dbms) endif () diff --git a/dbms/src/Storages/DeltaMerge/BitmapFilter/BitmapFilterView.h b/dbms/src/Storages/DeltaMerge/BitmapFilter/BitmapFilterView.h index 09804aaf077..02042911c7c 100644 --- a/dbms/src/Storages/DeltaMerge/BitmapFilter/BitmapFilterView.h +++ b/dbms/src/Storages/DeltaMerge/BitmapFilter/BitmapFilterView.h @@ -37,11 +37,19 @@ class BitmapFilterView RUNTIME_CHECK(filter_offset + filter_size <= filter->size(), filter_offset, filter_size, filter->size()); } + /** + * @brief Create a BitmapFilter and construct a BitmapFilterView with it. + * Should be only used in tests. + */ + static BitmapFilterView createWithFilter(UInt32 size, bool default_value) + { + return BitmapFilterView(std::make_shared(size, default_value), 0, size); + } + // Caller should ensure n in [0, size). inline bool get(UInt32 n) const { return filter->get(filter_offset + n); } inline bool operator[](UInt32 n) const { return get(n); } - inline UInt32 size() const { return filter_size; } inline UInt32 offset() const { return filter_offset; } diff --git a/dbms/src/Storages/DeltaMerge/DeltaMergeStore_InternalSegment.cpp b/dbms/src/Storages/DeltaMerge/DeltaMergeStore_InternalSegment.cpp index 00d3f012b8a..3226b6b618b 100644 --- a/dbms/src/Storages/DeltaMerge/DeltaMergeStore_InternalSegment.cpp +++ b/dbms/src/Storages/DeltaMerge/DeltaMergeStore_InternalSegment.cpp @@ -39,6 +39,11 @@ extern const Metric DT_SnapshotOfSegmentIngest; extern const Metric DT_SnapshotOfSegmentIngestIndex; } // namespace CurrentMetrics +namespace DB::ErrorCodes +{ +extern const int ABORTED; +} + namespace DB::DM { @@ -673,18 +678,20 @@ void DeltaMergeStore::segmentEnsureStableIndex( RUNTIME_CHECK(dm_files.size() == 1); // size > 1 is currently not supported. const auto & dm_file = dm_files[0]; - // 2. Check whether the DMFile has been referenced by any valid segment. - { + auto is_file_valid = [this, dm_file] { std::shared_lock lock(read_write_mutex); auto segment_ids = dmfile_id_to_segment_ids.get(dm_file->fileId()); - if (segment_ids.empty()) - { - LOG_DEBUG( - log, - "EnsureStableIndex - Give up because no segment to update, source_segment={}", - source_segment_info); - return; - } + return !segment_ids.empty(); + }; + + // 2. Check whether the DMFile has been referenced by any valid segment. + if (!is_file_valid()) + { + LOG_DEBUG( + log, + "EnsureStableIndex - Give up because no segment to update, source_segment={}", + source_segment_info); + return; } LOG_INFO( @@ -700,7 +707,29 @@ void DeltaMergeStore::segmentEnsureStableIndex( .dm_files = dm_files, .dm_context = dm_context, }); - auto new_dmfiles = iw.build(); + + DMFiles new_dmfiles{}; + + try + { + // When file is not valid we need to abort the index build. + new_dmfiles = iw.build(is_file_valid); + } + catch (const Exception & e) + { + if (e.code() == ErrorCodes::ABORTED) + { + LOG_INFO( + log, + "EnsureStableIndex - Build index aborted because DMFile is no longer valid, dm_files={} " + "source_segment={}", + DMFile::info(dm_files), + source_segment_info); + return; + } + throw; + } + RUNTIME_CHECK(!new_dmfiles.empty()); LOG_INFO( diff --git a/dbms/src/Storages/DeltaMerge/File/DMFileIndexWriter.cpp b/dbms/src/Storages/DeltaMerge/File/DMFileIndexWriter.cpp index aa879d4b206..7e438773438 100644 --- a/dbms/src/Storages/DeltaMerge/File/DMFileIndexWriter.cpp +++ b/dbms/src/Storages/DeltaMerge/File/DMFileIndexWriter.cpp @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include #include @@ -22,6 +23,11 @@ #include #include +namespace DB::ErrorCodes +{ +extern const int ABORTED; +} + namespace DB::DM { @@ -63,7 +69,7 @@ DMFileIndexWriter::LocalIndexBuildInfo DMFileIndexWriter::getLocalIndexBuildInfo return build; } -size_t DMFileIndexWriter::buildIndexForFile(const DMFilePtr & dm_file_mutable) const +size_t DMFileIndexWriter::buildIndexForFile(const DMFilePtr & dm_file_mutable, ProceedCheckFn should_proceed) const { const auto column_defines = dm_file_mutable->getColumnDefines(); const auto del_cd_iter = std::find_if(column_defines.cbegin(), column_defines.cend(), [](const ColumnDefine & cd) { @@ -128,6 +134,9 @@ size_t DMFileIndexWriter::buildIndexForFile(const DMFilePtr & dm_file_mutable) c // Read all blocks and build index while (true) { + if (!should_proceed()) + throw Exception(ErrorCodes::ABORTED, "Index build is interrupted"); + auto block = read_stream->read(); if (!block) break; @@ -146,7 +155,7 @@ size_t DMFileIndexWriter::buildIndexForFile(const DMFilePtr & dm_file_mutable) c const auto & col_with_type_and_name = block.safeGetByPosition(col_idx + 1); RUNTIME_CHECK(col_with_type_and_name.column_id == read_columns[col_idx + 1].id); const auto & col = col_with_type_and_name.column; - index_builder->addBlock(*col, del_mark); + index_builder->addBlock(*col, del_mark, should_proceed); } } @@ -187,7 +196,7 @@ size_t DMFileIndexWriter::buildIndexForFile(const DMFilePtr & dm_file_mutable) c return total_built_index_bytes; } -DMFiles DMFileIndexWriter::build() const +DMFiles DMFileIndexWriter::build(ProceedCheckFn should_proceed) const { RUNTIME_CHECK(!built); // Create a clone of existing DMFile instances by using DMFile::restore, @@ -214,7 +223,7 @@ DMFiles DMFileIndexWriter::build() const for (const auto & cloned_dmfile : cloned_dm_files) { - auto index_bytes = buildIndexForFile(cloned_dmfile); + auto index_bytes = buildIndexForFile(cloned_dmfile, should_proceed); if (auto data_store = options.dm_context.global_context.getSharedContextDisagg()->remote_data_store; !data_store) { diff --git a/dbms/src/Storages/DeltaMerge/File/DMFileIndexWriter.h b/dbms/src/Storages/DeltaMerge/File/DMFileIndexWriter.h index 8604e682186..09b64999809 100644 --- a/dbms/src/Storages/DeltaMerge/File/DMFileIndexWriter.h +++ b/dbms/src/Storages/DeltaMerge/File/DMFileIndexWriter.h @@ -57,16 +57,23 @@ class DMFileIndexWriter const DMContext & dm_context; }; + using ProceedCheckFn = std::function; + explicit DMFileIndexWriter(const Options & options) : logger(Logger::get()) , options(options) {} - // Note: This method can only be called once. - DMFiles build() const; + // Note: You cannot call build() multiple times, as duplicate meta version will result in exceptions. + DMFiles build(ProceedCheckFn should_proceed) const; + + DMFiles build() const + { + return build([]() { return true; }); + } private: - size_t buildIndexForFile(const DMFilePtr & dm_file_mutable) const; + size_t buildIndexForFile(const DMFilePtr & dm_file_mutable, ProceedCheckFn should_proceed) const; private: const LoggerPtr logger; diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorIndex.h b/dbms/src/Storages/DeltaMerge/Index/VectorIndex.h index 05457611336..5302cdd3788 100644 --- a/dbms/src/Storages/DeltaMerge/Index/VectorIndex.h +++ b/dbms/src/Storages/DeltaMerge/Index/VectorIndex.h @@ -35,6 +35,8 @@ class VectorIndexBuilder /// The key is the row's offset in the DMFile. using Key = UInt32; + using ProceedCheckFn = std::function; + public: static VectorIndexBuilderPtr create(const TiDB::VectorIndexDefinitionPtr & definition); @@ -47,7 +49,11 @@ class VectorIndexBuilder virtual ~VectorIndexBuilder() = default; - virtual void addBlock(const IColumn & column, const ColumnVector * del_mark) = 0; + virtual void addBlock( // + const IColumn & column, + const ColumnVector * del_mark, + ProceedCheckFn should_proceed) + = 0; virtual void save(std::string_view path) const = 0; @@ -80,6 +86,8 @@ class VectorIndexViewer // Invalid rows in `valid_rows` will be discared when applying the search virtual std::vector search(const ANNQueryInfoPtr & queryInfo, const RowFilter & valid_rows) const = 0; + virtual size_t size() const = 0; + // Get the value (i.e. vector content) of a Key. virtual void get(Key key, std::vector & out) const = 0; diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorIndexHNSW/Index.cpp b/dbms/src/Storages/DeltaMerge/Index/VectorIndexHNSW/Index.cpp index 4ba17fd4549..e2f2d90b476 100644 --- a/dbms/src/Storages/DeltaMerge/Index/VectorIndexHNSW/Index.cpp +++ b/dbms/src/Storages/DeltaMerge/Index/VectorIndexHNSW/Index.cpp @@ -31,6 +31,7 @@ namespace DB::ErrorCodes extern const int INCORRECT_DATA; extern const int INCORRECT_QUERY; extern const int CANNOT_ALLOCATE_MEMORY; +extern const int ABORTED; } // namespace DB::ErrorCodes namespace DB::DM @@ -62,11 +63,14 @@ VectorIndexHNSWBuilder::VectorIndexHNSWBuilder(const TiDB::VectorIndexDefinition definition_->dimension, getUSearchMetricKind(definition->distance_metric)))) { - RUNTIME_CHECK(definition_->kind == tipb::VectorIndexKind::HNSW); + RUNTIME_CHECK(definition_->kind == kind()); GET_METRIC(tiflash_vector_index_active_instances, type_build).Increment(); } -void VectorIndexHNSWBuilder::addBlock(const IColumn & column, const ColumnVector * del_mark) +void VectorIndexHNSWBuilder::addBlock( + const IColumn & column, + const ColumnVector * del_mark, + ProceedCheckFn should_proceed) { // Note: column may be nullable. const ColumnArray * col_array; @@ -88,11 +92,21 @@ void VectorIndexHNSWBuilder::addBlock(const IColumn & column, const ColumnVector Stopwatch w; SCOPE_EXIT({ total_duration += w.elapsedSeconds(); }); + Stopwatch w_proceed_check(CLOCK_MONOTONIC_COARSE); + for (int i = 0, i_max = col_array->size(); i < i_max; ++i) { auto row_offset = added_rows; added_rows++; + if (unlikely(i % 100 == 0 && w_proceed_check.elapsedSeconds() > 0.5)) + { + // The check of should_proceed could be non-trivial, so do it not too often. + w_proceed_check.restart(); + if (!should_proceed()) + throw Exception(ErrorCodes::ABORTED, "Index build is interrupted"); + } + // Ignore rows with del_mark, as the column values are not meaningful. if (del_mark_data != nullptr && (*del_mark_data)[i]) continue; @@ -139,9 +153,14 @@ VectorIndexHNSWBuilder::~VectorIndexHNSWBuilder() GET_METRIC(tiflash_vector_index_active_instances, type_build).Decrement(); } +tipb::VectorIndexKind VectorIndexHNSWBuilder::kind() +{ + return tipb::VectorIndexKind::HNSW; +} + VectorIndexViewerPtr VectorIndexHNSWViewer::view(const dtpb::VectorIndexFileProps & file_props, std::string_view path) { - RUNTIME_CHECK(file_props.index_kind() == tipb::VectorIndexKind_Name(tipb::VectorIndexKind::HNSW)); + RUNTIME_CHECK(file_props.index_kind() == tipb::VectorIndexKind_Name(kind())); tipb::VectorDistanceMetric metric; RUNTIME_CHECK(tipb::VectorDistanceMetric_Parse(file_props.distance_metric(), &metric)); @@ -151,9 +170,15 @@ VectorIndexViewerPtr VectorIndexHNSWViewer::view(const dtpb::VectorIndexFileProp SCOPE_EXIT({ GET_METRIC(tiflash_vector_index_duration, type_view).Observe(w.elapsedSeconds()); }); auto vi = std::make_shared(file_props); - vi->index = USearchImplType::make(unum::usearch::metric_punned_t( // - file_props.dimensions(), - getUSearchMetricKind(metric))); + vi->index = USearchImplType::make( + unum::usearch::metric_punned_t( // + file_props.dimensions(), + getUSearchMetricKind(metric)), + unum::usearch::index_dense_config_t( + unum::usearch::default_connectivity(), + unum::usearch::default_expansion_add(), + 16 /* default is 64 */)); + auto result = vi->index.view(unum::usearch::memory_mapped_file_t(path.data())); RUNTIME_CHECK_MSG(result, "Failed to load vector index: {}", result.error.what()); @@ -236,6 +261,11 @@ std::vector VectorIndexHNSWViewer::search( return keys; } +size_t VectorIndexHNSWViewer::size() const +{ + return index.size(); +} + void VectorIndexHNSWViewer::get(Key key, std::vector & out) const { out.resize(file_props.dimensions()); @@ -254,4 +284,9 @@ VectorIndexHNSWViewer::~VectorIndexHNSWViewer() GET_METRIC(tiflash_vector_index_active_instances, type_view).Decrement(); } +tipb::VectorIndexKind VectorIndexHNSWViewer::kind() +{ + return tipb::VectorIndexKind::HNSW; +} + } // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorIndexHNSW/Index.h b/dbms/src/Storages/DeltaMerge/Index/VectorIndexHNSW/Index.h index 38f11a6f760..1804fbe80ea 100644 --- a/dbms/src/Storages/DeltaMerge/Index/VectorIndexHNSW/Index.h +++ b/dbms/src/Storages/DeltaMerge/Index/VectorIndexHNSW/Index.h @@ -27,11 +27,13 @@ using USearchImplType = unum::usearch:: class VectorIndexHNSWBuilder : public VectorIndexBuilder { public: + static tipb::VectorIndexKind kind(); + explicit VectorIndexHNSWBuilder(const TiDB::VectorIndexDefinitionPtr & definition_); ~VectorIndexHNSWBuilder() override; - void addBlock(const IColumn & column, const ColumnVector * del_mark) override; + void addBlock(const IColumn & column, const ColumnVector * del_mark, ProceedCheckFn should_proceed) override; void save(std::string_view path) const override; @@ -48,12 +50,16 @@ class VectorIndexHNSWViewer : public VectorIndexViewer public: static VectorIndexViewerPtr view(const dtpb::VectorIndexFileProps & props, std::string_view path); + static tipb::VectorIndexKind kind(); + explicit VectorIndexHNSWViewer(const dtpb::VectorIndexFileProps & props); ~VectorIndexHNSWViewer() override; std::vector search(const ANNQueryInfoPtr & query_info, const RowFilter & valid_rows) const override; + size_t size() const override; + void get(Key key, std::vector & out) const override; private: diff --git a/dbms/src/Storages/DeltaMerge/tests/bench_dataset/.gitignore b/dbms/src/Storages/DeltaMerge/tests/bench_dataset/.gitignore new file mode 100644 index 00000000000..300cf170ff5 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/tests/bench_dataset/.gitignore @@ -0,0 +1 @@ +*.hdf5 diff --git a/dbms/src/Storages/DeltaMerge/tests/bench_dataset/README.md b/dbms/src/Storages/DeltaMerge/tests/bench_dataset/README.md new file mode 100644 index 00000000000..ca8e5ec402d --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/tests/bench_dataset/README.md @@ -0,0 +1,7 @@ +# Benchmark Datasets + +To prepare datasets: + +```shell +wget https://ann-benchmarks.com/fashion-mnist-784-euclidean.hdf5 +``` diff --git a/dbms/src/Storages/DeltaMerge/tests/bench_vector_index.cpp b/dbms/src/Storages/DeltaMerge/tests/bench_vector_index.cpp new file mode 100644 index 00000000000..4328625b361 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/tests/bench_vector_index.cpp @@ -0,0 +1,98 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS,n +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +namespace DB::DM::bench +{ + +static void VectorIndexBuild(::benchmark::State & state) +try +{ + const auto & dataset = DatasetMnist::get(); + + auto train_data = dataset.buildDataTrainColumn(/* max_rows= */ 10000); + auto index_def = dataset.createIndexDef(tipb::VectorIndexKind::HNSW); + for (auto _ : state) + { + auto builder = std::make_unique(index_def); + builder->addBlock(*train_data, nullptr, []() { return true; }); + } +} +CATCH + +static void VectorIndexSearchTop10(::benchmark::State & state) +try +{ + const auto & dataset = DatasetMnist::get(); + + auto index_path = DB::tests::TiFlashTestEnv::getTemporaryPath("vector_search_top_10/vector_index.idx"); + VectorIndexBenchUtils::saveVectorIndex( // + index_path, + dataset, + /* max_rows= */ 10000); + + auto viewer = VectorIndexBenchUtils::viewVectorIndex(index_path, dataset); + + std::random_device rd; + std::mt19937 rng(rd()); + std::uniform_int_distribution dist(0, dataset.dataTestSize() - 1); + + for (auto _ : state) + { + auto test_index = dist(rng); + const auto & query_vector = DatasetMnist::get().dataTestAt(test_index); + auto keys = VectorIndexBenchUtils::queryTopK(viewer, query_vector, 10, state); + RUNTIME_CHECK(keys.size() == 10); + } +} +CATCH + +static void VectorIndexSearchTop100(::benchmark::State & state) +try +{ + const auto & dataset = DatasetMnist::get(); + + auto index_path = DB::tests::TiFlashTestEnv::getTemporaryPath("vector_search_top_10/vector_index.idx"); + VectorIndexBenchUtils::saveVectorIndex( // + index_path, + dataset, + /* max_rows= */ 10000); + + auto viewer = VectorIndexBenchUtils::viewVectorIndex(index_path, dataset); + + std::random_device rd; + std::mt19937 rng(rd()); + std::uniform_int_distribution dist(0, dataset.dataTestSize() - 1); + + for (auto _ : state) + { + auto test_index = dist(rng); + const auto & query_vector = DatasetMnist::get().dataTestAt(test_index); + auto keys = VectorIndexBenchUtils::queryTopK(viewer, query_vector, 100, state); + RUNTIME_CHECK(keys.size() == 100); + } +} +CATCH + +BENCHMARK(VectorIndexBuild); + +BENCHMARK(VectorIndexSearchTop10); + +BENCHMARK(VectorIndexSearchTop100); + +} // namespace DB::DM::bench diff --git a/dbms/src/Storages/DeltaMerge/tests/bench_vector_index_utils.h b/dbms/src/Storages/DeltaMerge/tests/bench_vector_index_utils.h new file mode 100644 index 00000000000..9203b059d3a --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/tests/bench_vector_index_utils.h @@ -0,0 +1,178 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS,n +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace DB::DM::bench +{ + +/** + * @brief Compatible with datasets on ANN-Benchmark: + * https://github.com/erikbern/ann-benchmarks + */ +class Dataset +{ +public: + explicit Dataset(std::string_view file_name) + { + auto dataset_directory = std::filesystem::path(__FILE__).parent_path().string() + "/bench_dataset"; + auto dataset_path = fmt::format("{}/{}", dataset_directory, file_name); + + if (!std::filesystem::exists(dataset_path)) + { + throw Exception(fmt::format( + "Benchmark cannot run because dataset file {} not found. See {}/README.md for setup instructions.", + dataset_path, + dataset_directory)); + } + + auto file = HighFive::File(dataset_path, HighFive::File::ReadOnly); + + auto dataset_train = file.getDataSet("train"); + dataset_train.read(data_train); + + auto dataset_test = file.getDataSet("test"); + dataset_test.read(data_test); + } + + virtual ~Dataset() = default; + + virtual UInt32 dimension() const = 0; + + virtual tipb::VectorDistanceMetric distanceMetric() const = 0; + +public: + MutableColumnPtr buildDataTrainColumn(std::optional max_rows = std::nullopt) const + { + auto vec_column = ColumnArray::create(ColumnFloat32::create()); + size_t rows = data_train.size(); + if (max_rows.has_value()) + rows = std::min(rows, *max_rows); + for (size_t i = 0; i < rows; ++i) + { + const auto & row = data_train[i]; + vec_column->insertData(reinterpret_cast(row.data()), row.size() * sizeof(Float32)); + } + return vec_column; + } + + size_t dataTestSize() const { return data_test.size(); } + + const std::vector & dataTestAt(size_t index) const { return data_test.at(index); } + + TiDB::VectorIndexDefinitionPtr createIndexDef(tipb::VectorIndexKind kind) const + { + return std::make_shared(TiDB::VectorIndexDefinition{ + .kind = kind, + .dimension = dimension(), + .distance_metric = distanceMetric(), + }); + } + +protected: + std::vector> data_train; + std::vector> data_test; +}; + +class DatasetMnist : public Dataset +{ +public: + DatasetMnist() + : Dataset("fashion-mnist-784-euclidean.hdf5") + { + RUNTIME_CHECK(data_train[0].size() == dimension()); + RUNTIME_CHECK(data_test[0].size() == dimension()); + } + + UInt32 dimension() const override { return 784; } + + tipb::VectorDistanceMetric distanceMetric() const override { return tipb::VectorDistanceMetric::L2; } + + static const DatasetMnist & get() + { + static DatasetMnist dataset; + return dataset; + } +}; + +class VectorIndexBenchUtils +{ +public: + template + static void saveVectorIndex( + std::string_view index_path, + const Dataset & dataset, + std::optional max_rows = std::nullopt) + { + Poco::File(index_path.data()).createDirectories(); + + auto train_data = dataset.buildDataTrainColumn(max_rows); + auto index_def = dataset.createIndexDef(Builder::kind()); + auto builder = std::make_unique(index_def); + builder->addBlock(*train_data, nullptr, []() { return true; }); + builder->save(index_path); + } + + template + static auto viewVectorIndex(std::string_view index_path, const Dataset & dataset) + { + auto index_view_props = dtpb::VectorIndexFileProps(); + index_view_props.set_index_kind(tipb::VectorIndexKind_Name(Viewer::kind())); + index_view_props.set_dimensions(dataset.dimension()); + index_view_props.set_distance_metric(tipb::VectorDistanceMetric_Name(dataset.distanceMetric())); + return Viewer::view(index_view_props, index_path); + } + + static auto queryTopK( + VectorIndexViewerPtr viewer, + const std::vector & ref, + UInt32 top_k, + std::optional> state = std::nullopt) + { + if (state.has_value()) + state->get().PauseTiming(); + + auto ann_query_info = std::make_shared(); + auto distance_metric = tipb::VectorDistanceMetric::INVALID_DISTANCE_METRIC; + tipb::VectorDistanceMetric_Parse(viewer->file_props.distance_metric(), &distance_metric); + ann_query_info->set_distance_metric(distance_metric); + ann_query_info->set_top_k(top_k); + ann_query_info->set_ref_vec_f32(DB::DM::tests::VectorIndexTestUtils::encodeVectorFloat32(ref)); + + auto filter = BitmapFilterView::createWithFilter(viewer->size(), true); + + if (state.has_value()) + state->get().ResumeTiming(); + + return viewer->search(ann_query_info, filter); + } +}; + + +} // namespace DB::DM::bench diff --git a/dbms/src/Storages/DeltaMerge/tests/gtest_dm_vector_index.cpp b/dbms/src/Storages/DeltaMerge/tests/gtest_dm_vector_index.cpp index 9dcbbfea0c1..aaf3cce6e9f 100644 --- a/dbms/src/Storages/DeltaMerge/tests/gtest_dm_vector_index.cpp +++ b/dbms/src/Storages/DeltaMerge/tests/gtest_dm_vector_index.cpp @@ -220,7 +220,7 @@ try DMFileBlockInputStreamBuilder builder(dbContext()); auto stream = builder.setRSOperator(wrapWithANNQueryInfo(nullptr, ann_query_info)) - .setBitmapFilter(BitmapFilterView(std::make_shared(3, true), 0, 3)) + .setBitmapFilter(BitmapFilterView::createWithFilter(3, true)) .tryBuildWithVectorIndex( dm_file, read_cols, @@ -245,7 +245,7 @@ try DMFileBlockInputStreamBuilder builder(dbContext()); auto stream = builder.setRSOperator(wrapWithANNQueryInfo(nullptr, ann_query_info)) - .setBitmapFilter(BitmapFilterView(std::make_shared(3, true), 0, 3)) + .setBitmapFilter(BitmapFilterView::createWithFilter(3, true)) .tryBuildWithVectorIndex( dm_file, read_cols, @@ -270,7 +270,7 @@ try DMFileBlockInputStreamBuilder builder(dbContext()); auto stream = builder.setRSOperator(wrapWithANNQueryInfo(nullptr, ann_query_info)) - .setBitmapFilter(BitmapFilterView(std::make_shared(3, true), 0, 3)) + .setBitmapFilter(BitmapFilterView::createWithFilter(3, true)) .tryBuildWithVectorIndex( dm_file, read_cols, @@ -323,7 +323,7 @@ try DMFileBlockInputStreamBuilder builder(dbContext()); auto stream = builder.setRSOperator(wrapWithANNQueryInfo(nullptr, ann_query_info)) - .setBitmapFilter(BitmapFilterView(std::make_shared(3, true), 0, 3)) + .setBitmapFilter(BitmapFilterView::createWithFilter(3, true)) .tryBuildWithVectorIndex( dm_file, read_cols, @@ -348,7 +348,7 @@ try DMFileBlockInputStreamBuilder builder(dbContext()); auto stream = builder.setRSOperator(wrapWithANNQueryInfo(nullptr, ann_query_info)) - .setBitmapFilter(BitmapFilterView(std::make_shared(3, true), 0, 3)) + .setBitmapFilter(BitmapFilterView::createWithFilter(3, true)) .tryBuildWithVectorIndex( dm_file, read_cols, @@ -373,7 +373,7 @@ try DMFileBlockInputStreamBuilder builder(dbContext()); auto stream = builder.setRSOperator(wrapWithANNQueryInfo(nullptr, ann_query_info)) - .setBitmapFilter(BitmapFilterView(std::make_shared(3, true), 0, 3)) + .setBitmapFilter(BitmapFilterView::createWithFilter(3, true)) .tryBuildWithVectorIndex( dm_file, read_cols, @@ -407,7 +407,7 @@ try DMFileBlockInputStreamBuilder builder(dbContext()); auto stream = builder.setRSOperator(wrapWithANNQueryInfo(nullptr, ann_query_info)) - .setBitmapFilter(BitmapFilterView(std::make_shared(3, true), 0, 3)) + .setBitmapFilter(BitmapFilterView::createWithFilter(3, true)) .tryBuildWithVectorIndex( dm_file, read_cols, @@ -432,7 +432,7 @@ try DMFileBlockInputStreamBuilder builder(dbContext()); auto stream = builder.setRSOperator(wrapWithANNQueryInfo(nullptr, ann_query_info)) - .setBitmapFilter(BitmapFilterView(std::make_shared(3, true), 0, 3)) + .setBitmapFilter(BitmapFilterView::createWithFilter(3, true)) .tryBuildWithVectorIndex( dm_file, read_cols, @@ -467,7 +467,7 @@ try DMFileBlockInputStreamBuilder builder(dbContext()); auto stream = builder.setRSOperator(wrapWithANNQueryInfo(nullptr, ann_query_info)) - .setBitmapFilter(BitmapFilterView(std::make_shared(3, true), 0, 3)) + .setBitmapFilter(BitmapFilterView::createWithFilter(3, true)) .tryBuildWithVectorIndex( dm_file, read_cols, @@ -530,7 +530,7 @@ try DMFileBlockInputStreamBuilder builder(dbContext()); auto stream = builder.setRSOperator(wrapWithANNQueryInfo(nullptr, ann_query_info)) - .setBitmapFilter(BitmapFilterView(std::make_shared(5, true), 0, 5)) + .setBitmapFilter(BitmapFilterView::createWithFilter(5, true)) .tryBuildWithVectorIndex( dm_file, read_cols, @@ -598,7 +598,7 @@ try DMFileBlockInputStreamBuilder builder(dbContext()); auto stream = builder.setRSOperator(wrapWithANNQueryInfo(nullptr, ann_query_info)) - .setBitmapFilter(BitmapFilterView(std::make_shared(6, true), 0, 6)) + .setBitmapFilter(BitmapFilterView::createWithFilter(6, true)) .tryBuildWithVectorIndex( dm_file, read_cols, @@ -623,7 +623,7 @@ try DMFileBlockInputStreamBuilder builder(dbContext()); auto stream = builder.setRSOperator(wrapWithANNQueryInfo(nullptr, ann_query_info)) - .setBitmapFilter(BitmapFilterView(std::make_shared(6, true), 0, 6)) + .setBitmapFilter(BitmapFilterView::createWithFilter(6, true)) .tryBuildWithVectorIndex( dm_file, read_cols, @@ -648,7 +648,7 @@ try DMFileBlockInputStreamBuilder builder(dbContext()); auto stream = builder.setRSOperator(wrapWithANNQueryInfo(nullptr, ann_query_info)) - .setBitmapFilter(BitmapFilterView(std::make_shared(6, true), 0, 6)) + .setBitmapFilter(BitmapFilterView::createWithFilter(6, true)) .tryBuildWithVectorIndex( dm_file, read_cols,