diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index e70f94c7..e86d2545 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -8,7 +8,7 @@ jobs: strategy: matrix: os: [ubuntu-latest, windows-latest] - python-version: ["3.6", "3.7", "3.8", "3.9", "3.10"] + python-version: ["3.7", "3.8", "3.9", "3.10"] steps: - uses: actions/checkout@v3 - uses: actions/setup-python@v4 @@ -19,6 +19,7 @@ jobs: run: python -m pip install . - name: Test + timeout-minutes: 15 run: python -m unittest discover -v --start-directory python_bindings/tests --pattern "*_test*.py" test_cpp: @@ -52,6 +53,7 @@ jobs: shell: bash - name: Test + timeout-minutes: 15 run: | cd build if [ "$RUNNER_OS" == "Windows" ]; then @@ -59,6 +61,8 @@ jobs: fi ./searchKnnCloserFirst_test ./searchKnnWithFilter_test + ./multiThreadLoad_test + ./multiThread_replace_test ./test_updates ./test_updates update shell: bash diff --git a/.gitignore b/.gitignore index a338107c..48f74604 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ var/ .idea/ .vscode/ .vs/ +**.DS_Store diff --git a/CMakeLists.txt b/CMakeLists.txt index e42d6cee..de951171 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -25,6 +25,12 @@ if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME) add_executable(searchKnnWithFilter_test examples/searchKnnWithFilter_test.cpp) target_link_libraries(searchKnnWithFilter_test hnswlib) + add_executable(multiThreadLoad_test examples/multiThreadLoad_test.cpp) + target_link_libraries(multiThreadLoad_test hnswlib) + + add_executable(multiThread_replace_test examples/multiThread_replace_test.cpp) + target_link_libraries(multiThread_replace_test hnswlib) + add_executable(main main.cpp sift_1b.cpp) target_link_libraries(main hnswlib) endif() diff --git a/README.md b/README.md index c86e4391..c0b0dbcc 100644 --- a/README.md +++ b/README.md @@ -54,19 +54,22 @@ For other spaces use the nmslib library https://github.com/nmslib/nmslib. * `hnswlib.Index(space, dim)` creates a non-initialized index an HNSW in space `space` with integer dimension `dim`. `hnswlib.Index` methods: -* `init_index(max_elements, M = 16, ef_construction = 200, random_seed = 100)` initializes the index from with no elements. +* `init_index(max_elements, M = 16, ef_construction = 200, random_seed = 100, allow_replace_deleted = False)` initializes the index from with no elements. * `max_elements` defines the maximum number of elements that can be stored in the structure(can be increased/shrunk). * `ef_construction` defines a construction time/accuracy trade-off (see [ALGO_PARAMS.md](ALGO_PARAMS.md)). * `M` defines tha maximum number of outgoing connections in the graph ([ALGO_PARAMS.md](ALGO_PARAMS.md)). + * `allow_replace_deleted` enables replacing of deleted elements with new added ones. -* `add_items(data, ids, num_threads = -1)` - inserts the `data`(numpy array of vectors, shape:`N*dim`) into the structure. +* `add_items(data, ids, num_threads = -1, replace_deleted = False)` - inserts the `data`(numpy array of vectors, shape:`N*dim`) into the structure. * `num_threads` sets the number of cpu threads to use (-1 means use default). * `ids` are optional N-size numpy array of integer labels for all elements in `data`. - If index already has the elements with the same labels, their features will be updated. Note that update procedure is slower than insertion of a new element, but more memory- and query-efficient. + * `replace_deleted` replaces deleted elements. Note it allows to save memory. 
+    - to use it, `init_index` should be called with `allow_replace_deleted=True`
    * Thread-safe with other `add_items` calls, but not with `knn_query`.

* `mark_deleted(label)` - marks the element as deleted, so it will be omitted from search results. Throws an exception if it is already deleted.
-*
+
* `unmark_deleted(label)` - unmarks the element as deleted, so it will not be omitted from search results.

* `resize_index(new_size)` - changes the maximum capacity of the index. Not thread-safe with `add_items` and `knn_query`.

@@ -74,13 +77,15 @@ For other spaces use the nmslib library https://github.com/nmslib/nmslib.
* `set_ef(ef)` - sets the query time accuracy/speed trade-off, defined by the `ef` parameter (
[ALGO_PARAMS.md](ALGO_PARAMS.md)). Note that the parameter is currently not saved along with the index, so you need to set it manually after loading.

-* `knn_query(data, k = 1, num_threads = -1)` make a batch query for `k` closest elements for each element of the
+* `knn_query(data, k = 1, num_threads = -1, filter = None)` makes a batch query for the `k` closest elements for each element of the
* `data` (shape:`N*dim`). Returns a numpy array of (shape:`N*k`).
    * `num_threads` sets the number of cpu threads to use (-1 means use default).
+    * `filter` filters elements by their labels, returning only elements with allowed ids.
    * Thread-safe with other `knn_query` calls, but not with `add_items`.

-* `load_index(path_to_index, max_elements = 0)` loads the index from persistence to the uninitialized index.
+* `load_index(path_to_index, max_elements = 0, allow_replace_deleted = False)` loads the index from persistent storage into an uninitialized index.
    * `max_elements` (optional) resets the maximum number of elements in the structure.
+    * `allow_replace_deleted` specifies whether the index being loaded was created with replacing of deleted elements enabled.

* `save_index(path_to_index)` saves the index to persistent storage.
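For reference, a minimal sketch of how the new parameters fit together across a save/load cycle, assuming only the bindings described above (the file name and sizes here are illustrative, and `allow_replace_deleted` has to be passed to both `init_index` and `load_index` for `replace_deleted=True` to be usable afterwards):

```python
import hnswlib
import numpy as np

dim = 16
num_elements = 1000

data = np.float32(np.random.random((num_elements, dim)))

# Build an index that allows replacing deleted elements
p = hnswlib.Index(space='l2', dim=dim)
p.init_index(max_elements=num_elements, ef_construction=200, M=16, allow_replace_deleted=True)
p.add_items(data, np.arange(num_elements))

# Persist and reload; the flag must be passed to load_index as well
p.save_index("index.bin")

p2 = hnswlib.Index(space='l2', dim=dim)
p2.load_index("index.bin", max_elements=num_elements, allow_replace_deleted=True)
p2.set_ef(50)  # ef is not stored in the index, so set it again after loading

# Free a slot and reuse it for a new vector with a new label
p2.mark_deleted(0)
new_vector = np.float32(np.random.random((1, dim)))
p2.add_items(new_vector, np.array([num_elements]), replace_deleted=True)
```

Loading with `allow_replace_deleted=False` and later calling `add_items(..., replace_deleted=True)` raises a runtime error, because the replacement bookkeeping is only set up at construction/load time.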
@@ -142,7 +147,7 @@ p.add_items(data, ids) # Controlling the recall by setting ef: p.set_ef(50) # ef should always be > k -# Query dataset, k - number of closest elements (returns 2 numpy arrays) +# Query dataset, k - number of the closest elements (returns 2 numpy arrays) labels, distances = p.knn_query(data, k = 1) # Index objects support pickling @@ -155,7 +160,6 @@ print(f"Parameters passed to constructor: space={p_copy.space}, dim={p_copy.dim print(f"Index construction: M={p_copy.M}, ef_construction={p_copy.ef_construction}") print(f"Index size is {p_copy.element_count} and index capacity is {p_copy.max_elements}") print(f"Search speed/quality trade-off parameter: ef={p_copy.ef}") - ``` An example with updates after serialization/deserialization: @@ -196,7 +200,6 @@ p.set_ef(10) # By default using all available cores p.set_num_threads(4) - print("Adding first batch of %d elements" % (len(data1))) p.add_items(data1) @@ -226,6 +229,104 @@ labels, distances = p.knn_query(data, k=1) print("Recall for two batches:", np.mean(labels.reshape(-1) == np.arange(len(data))), "\n") ``` +An example with a filter: +```python +import hnswlib +import numpy as np + +dim = 16 +num_elements = 10000 + +# Generating sample data +data = np.float32(np.random.random((num_elements, dim))) + +# Declaring index +hnsw_index = hnswlib.Index(space='l2', dim=dim) # possible options are l2, cosine or ip + +# Initiating index +# max_elements - the maximum number of elements, should be known beforehand +# (probably will be made optional in the future) +# +# ef_construction - controls index search speed/build speed tradeoff +# M - is tightly connected with internal dimensionality of the data +# strongly affects the memory consumption + +hnsw_index.init_index(max_elements=num_elements, ef_construction=100, M=16) + +# Controlling the recall by setting ef: +# higher ef leads to better accuracy, but slower search +hnsw_index.set_ef(10) + +# Set number of threads used during batch search/construction +# By default using all available cores +hnsw_index.set_num_threads(4) + +print("Adding %d elements" % (len(data))) +# Added elements will have consecutive ids +hnsw_index.add_items(data, ids=np.arange(num_elements)) + +print("Querying only even elements") +# Define filter function that allows only even ids +filter_function = lambda idx: idx%2 == 0 +# Query the elements for themselves and search only for even elements: +labels, distances = hnsw_index.knn_query(data, k=1, filter=filter_function) +# labels contain only elements with even id +``` + +An example with replacing of deleted elements: +```python +import hnswlib +import numpy as np + +dim = 16 +num_elements = 1_000 +max_num_elements = 2 * num_elements + +# Generating sample data +labels1 = np.arange(0, num_elements) +data1 = np.float32(np.random.random((num_elements, dim))) # batch 1 +labels2 = np.arange(num_elements, 2 * num_elements) +data2 = np.float32(np.random.random((num_elements, dim))) # batch 2 +labels3 = np.arange(2 * num_elements, 3 * num_elements) +data3 = np.float32(np.random.random((num_elements, dim))) # batch 3 + +# Declaring index +hnsw_index = hnswlib.Index(space='l2', dim=dim) + +# Initiating index +# max_elements - the maximum number of elements, should be known beforehand +# (probably will be made optional in the future) +# +# ef_construction - controls index search speed/build speed tradeoff +# M - is tightly connected with internal dimensionality of the data +# strongly affects the memory consumption + +# Enable replacing of deleted elements 
+hnsw_index.init_index(max_elements=max_num_elements, ef_construction=200, M=16, allow_replace_deleted=True) + +# Controlling the recall by setting ef: +# higher ef leads to better accuracy, but slower search +hnsw_index.set_ef(10) + +# Set number of threads used during batch search/construction +# By default using all available cores +hnsw_index.set_num_threads(4) + +# Add batch 1 and 2 data +hnsw_index.add_items(data1, labels1) +hnsw_index.add_items(data2, labels2) # Note: maximum number of elements is reached + +# Delete data of batch 2 +for label in labels2: + hnsw_index.mark_deleted(label) + +# Replace deleted elements +# Maximum number of elements is reached therefore we cannot add new items, +# but we can replace the deleted ones by using replace_deleted=True +hnsw_index.add_items(data3, labels3, replace_deleted=True) +# hnsw_index contains the data of batch 1 and batch 3 only +``` + ### Bindings installation You can install from sources: diff --git a/examples/multiThreadLoad_test.cpp b/examples/multiThreadLoad_test.cpp new file mode 100644 index 00000000..a713b2ba --- /dev/null +++ b/examples/multiThreadLoad_test.cpp @@ -0,0 +1,140 @@ +#include "../hnswlib/hnswlib.h" +#include +#include + + +int main() { + std::cout << "Running multithread load test" << std::endl; + int d = 16; + int max_elements = 1000; + + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib_real; + + hnswlib::L2Space space(d); + hnswlib::HierarchicalNSW* alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * max_elements); + + std::cout << "Building index" << std::endl; + int num_threads = 40; + int num_labels = 10; + + int num_iterations = 10; + int start_label = 0; + + // run threads that will add elements to the index + // about 7 threads (the number depends on num_threads and num_labels) + // will add/update element with the same label simultaneously + while (true) { + // add elements by batches + std::uniform_int_distribution<> distrib_int(start_label, start_label + num_labels - 1); + std::vector threads; + for (size_t thread_id = 0; thread_id < num_threads; thread_id++) { + threads.push_back( + std::thread( + [&] { + for (int iter = 0; iter < num_iterations; iter++) { + std::vector data(d); + hnswlib::labeltype label = distrib_int(rng); + for (int i = 0; i < d; i++) { + data[i] = distrib_real(rng); + } + alg_hnsw->addPoint(data.data(), label); + } + } + ) + ); + } + for (auto &thread : threads) { + thread.join(); + } + if (alg_hnsw->cur_element_count > max_elements - num_labels) { + break; + } + start_label += num_labels; + } + + // insert remaining elements if needed + for (hnswlib::labeltype label = 0; label < max_elements; label++) { + auto search = alg_hnsw->label_lookup_.find(label); + if (search == alg_hnsw->label_lookup_.end()) { + std::cout << "Adding " << label << std::endl; + std::vector data(d); + for (int i = 0; i < d; i++) { + data[i] = distrib_real(rng); + } + alg_hnsw->addPoint(data.data(), label); + } + } + + std::cout << "Index is created" << std::endl; + + bool stop_threads = false; + std::vector threads; + + // create threads that will do markDeleted and unmarkDeleted of random elements + // each thread works with specific range of labels + std::cout << "Starting markDeleted and unmarkDeleted threads" << std::endl; + num_threads = 20; + int chunk_size = max_elements / num_threads; + for (size_t thread_id = 0; thread_id < num_threads; thread_id++) { + threads.push_back( + std::thread( + [&, thread_id] { + std::uniform_int_distribution<> distrib_int(0, chunk_size - 1); 
+ int start_id = thread_id * chunk_size; + std::vector marked_deleted(chunk_size); + while (!stop_threads) { + int id = distrib_int(rng); + hnswlib::labeltype label = start_id + id; + if (marked_deleted[id]) { + alg_hnsw->unmarkDelete(label); + marked_deleted[id] = false; + } else { + alg_hnsw->markDelete(label); + marked_deleted[id] = true; + } + } + } + ) + ); + } + + // create threads that will add and update random elements + std::cout << "Starting add and update elements threads" << std::endl; + num_threads = 20; + std::uniform_int_distribution<> distrib_int_add(max_elements, 2 * max_elements - 1); + for (size_t thread_id = 0; thread_id < num_threads; thread_id++) { + threads.push_back( + std::thread( + [&] { + std::vector data(d); + while (!stop_threads) { + hnswlib::labeltype label = distrib_int_add(rng); + for (int i = 0; i < d; i++) { + data[i] = distrib_real(rng); + } + alg_hnsw->addPoint(data.data(), label); + std::vector data = alg_hnsw->getDataByLabel(label); + float max_val = *max_element(data.begin(), data.end()); + // never happens but prevents compiler from deleting unused code + if (max_val > 10) { + throw std::runtime_error("Unexpected value in data"); + } + } + } + ) + ); + } + + std::cout << "Sleep and continue operations with index" << std::endl; + int sleep_ms = 60 * 1000; + std::this_thread::sleep_for(std::chrono::milliseconds(sleep_ms)); + stop_threads = true; + for (auto &thread : threads) { + thread.join(); + } + + std::cout << "Finish" << std::endl; + return 0; +} diff --git a/examples/multiThread_replace_test.cpp b/examples/multiThread_replace_test.cpp new file mode 100644 index 00000000..83ed2826 --- /dev/null +++ b/examples/multiThread_replace_test.cpp @@ -0,0 +1,121 @@ +#include "../hnswlib/hnswlib.h" +#include +#include + + +template +inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn) { + if (numThreads <= 0) { + numThreads = std::thread::hardware_concurrency(); + } + + if (numThreads == 1) { + for (size_t id = start; id < end; id++) { + fn(id, 0); + } + } else { + std::vector threads; + std::atomic current(start); + + // keep track of exceptions in threads + // https://stackoverflow.com/a/32428427/1713196 + std::exception_ptr lastException = nullptr; + std::mutex lastExceptMutex; + + for (size_t threadId = 0; threadId < numThreads; ++threadId) { + threads.push_back(std::thread([&, threadId] { + while (true) { + size_t id = current.fetch_add(1); + + if (id >= end) { + break; + } + + try { + fn(id, threadId); + } catch (...) { + std::unique_lock lastExcepLock(lastExceptMutex); + lastException = std::current_exception(); + /* + * This will work even when current is the largest value that + * size_t can fit, because fetch_add returns the previous value + * before the increment (what will result in overflow + * and produce 0 instead of current + 1). 
+ */ + current = end; + break; + } + } + })); + } + for (auto &thread : threads) { + thread.join(); + } + if (lastException) { + std::rethrow_exception(lastException); + } + } +} + + +int main() { + std::cout << "Running multithread load test" << std::endl; + int d = 16; + int num_elements = 1000; + int max_elements = 2 * num_elements; + int num_threads = 50; + + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib_real; + + hnswlib::L2Space space(d); + + // generate batch1 and batch2 data + float* batch1 = new float[d * max_elements]; + for (int i = 0; i < d * max_elements; i++) { + batch1[i] = distrib_real(rng); + } + float* batch2 = new float[d * num_elements]; + for (int i = 0; i < d * num_elements; i++) { + batch2[i] = distrib_real(rng); + } + + // generate random labels to delete them from index + std::vector rand_labels(max_elements); + for (int i = 0; i < max_elements; i++) { + rand_labels[i] = i; + } + std::shuffle(rand_labels.begin(), rand_labels.end(), rng); + + int iter = 0; + while (iter < 200) { + hnswlib::HierarchicalNSW* alg_hnsw = new hnswlib::HierarchicalNSW(&space, max_elements, 16, 200, 123, true); + + // add batch1 data + ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) { + alg_hnsw->addPoint((void*)(batch1 + d * row), row); + }); + + // delete half random elements of batch1 data + for (int i = 0; i < num_elements; i++) { + alg_hnsw->markDelete(rand_labels[i]); + } + + // replace deleted elements with batch2 data + ParallelFor(0, num_elements, num_threads, [&](size_t row, size_t threadId) { + int label = rand_labels[row] + max_elements; + alg_hnsw->addPoint((void*)(batch2 + d * row), label, true); + }); + + iter += 1; + + delete alg_hnsw; + } + + std::cout << "Finish" << std::endl; + + delete[] batch1; + delete[] batch2; + return 0; +} diff --git a/hnswlib/bruteforce.h b/hnswlib/bruteforce.h index 21130090..30b33ae9 100644 --- a/hnswlib/bruteforce.h +++ b/hnswlib/bruteforce.h @@ -61,7 +61,7 @@ class BruteforceSearch : public AlgorithmInterface { } - void addPoint(const void *datapoint, labeltype label) { + void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) { int idx; { std::unique_lock lock(index_lock); diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index 25995134..7f34e62b 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -16,11 +16,11 @@ typedef unsigned int linklistsizeint; template class HierarchicalNSW : public AlgorithmInterface { public: - static const tableint max_update_element_locks = 65536; + static const tableint MAX_LABEL_OPERATION_LOCKS = 65536; static const unsigned char DELETE_MARK = 0x01; size_t max_elements_{0}; - size_t cur_element_count{0}; + mutable std::atomic cur_element_count{0}; // current number of elements size_t size_data_per_element_{0}; size_t size_links_per_element_{0}; mutable std::atomic num_deleted_{0}; // number of deleted elements @@ -35,13 +35,10 @@ class HierarchicalNSW : public AlgorithmInterface { VisitedListPool *visited_list_pool_{nullptr}; - // Locks to prevent race condition during update/insert of an element at same time. - // Note: Locks for additions can also be used to prevent this race condition - // if the querying of KNN is not exposed along with update/inserts i.e multithread insert/update/query in parallel. 
- std::vector link_list_update_locks_; + // Locks operations with element by label value + mutable std::vector label_op_locks_; std::mutex global; - std::mutex cur_element_count_guard_; std::vector link_list_locks_; tableint enterpoint_node_{0}; @@ -57,7 +54,8 @@ class HierarchicalNSW : public AlgorithmInterface { DISTFUNC fstdistfunc_; void *dist_func_param_{nullptr}; - std::mutex label_lookup_lock; + + mutable std::mutex label_lookup_lock; // lock for label_lookup_ std::unordered_map label_lookup_; std::default_random_engine level_generator_; @@ -66,6 +64,11 @@ class HierarchicalNSW : public AlgorithmInterface { mutable std::atomic metric_distance_computations{0}; mutable std::atomic metric_hops{0}; + bool allow_replace_deleted_ = false; // flag to replace deleted elements (marked as deleted) during insertions + + std::mutex deleted_elements_lock; // lock for deleted_elements + std::unordered_set deleted_elements; // contains internal ids of deleted elements + HierarchicalNSW(SpaceInterface *s) { } @@ -75,7 +78,9 @@ class HierarchicalNSW : public AlgorithmInterface { SpaceInterface *s, const std::string &location, bool nmslib = false, - size_t max_elements = 0) { + size_t max_elements = 0, + bool allow_replace_deleted = false) + : allow_replace_deleted_(allow_replace_deleted) { loadIndex(location, s, max_elements); } @@ -85,10 +90,12 @@ class HierarchicalNSW : public AlgorithmInterface { size_t max_elements, size_t M = 16, size_t ef_construction = 200, - size_t random_seed = 100) + size_t random_seed = 100, + bool allow_replace_deleted = false) : link_list_locks_(max_elements), - link_list_update_locks_(max_update_element_locks), - element_levels_(max_elements) { + label_op_locks_(MAX_LABEL_OPERATION_LOCKS), + element_levels_(max_elements), + allow_replace_deleted_(allow_replace_deleted) { max_elements_ = max_elements; num_deleted_ = 0; data_size_ = s->get_data_size(); @@ -154,6 +161,13 @@ class HierarchicalNSW : public AlgorithmInterface { } + inline std::mutex& getLabelOpMutex(labeltype label) const { + // calculate hash + size_t lock_id = label & (MAX_LABEL_OPERATION_LOCKS - 1); + return label_op_locks_[lock_id]; + } + + inline labeltype getExternalLabel(tableint internal_id) const { labeltype return_label; memcpy(&return_label, (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), sizeof(labeltype)); @@ -437,6 +451,12 @@ class HierarchicalNSW : public AlgorithmInterface { tableint next_closest_entry_point = selectedNeighbors.back(); { + // lock only during the update + // because during the addition the lock for cur_c is already acquired + std::unique_lock lock(link_list_locks_[cur_c], std::defer_lock); + if (isUpdate) { + lock.lock(); + } linklistsizeint *ll_cur; if (level == 0) ll_cur = get_linklist0(cur_c); @@ -664,7 +684,7 @@ class HierarchicalNSW : public AlgorithmInterface { size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); std::vector(max_elements).swap(link_list_locks_); - std::vector(max_update_element_locks).swap(link_list_update_locks_); + std::vector(MAX_LABEL_OPERATION_LOCKS).swap(label_op_locks_); visited_list_pool_ = new VisitedListPool(1, max_elements); @@ -693,6 +713,7 @@ class HierarchicalNSW : public AlgorithmInterface { for (size_t i = 0; i < cur_element_count; i++) { if (isMarkedDeleted(i)) { num_deleted_ += 1; + if (allow_replace_deleted_) deleted_elements.insert(i); } } @@ -704,14 +725,18 @@ class HierarchicalNSW : public AlgorithmInterface { template std::vector getDataByLabel(labeltype label) const { - tableint 
label_internal; + // lock all operations with element by label + std::unique_lock lock_label(getLabelOpMutex(label)); + + std::unique_lock lock_table(label_lookup_lock); auto search = label_lookup_.find(label); if (search == label_lookup_.end() || isMarkedDeleted(search->second)) { throw std::runtime_error("Label not found"); } - label_internal = search->second; + tableint internalId = search->second; + lock_table.unlock(); - char* data_ptrv = getDataByInternalId(label_internal); + char* data_ptrv = getDataByInternalId(internalId); size_t dim = *((size_t *) dist_func_param_); std::vector data; data_t* data_ptr = (data_t*) data_ptrv; @@ -723,66 +748,90 @@ class HierarchicalNSW : public AlgorithmInterface { } - /** - * Marks an element with the given label deleted, does NOT really change the current graph. - */ + /* + * Marks an element with the given label deleted, does NOT really change the current graph. + */ void markDelete(labeltype label) { + // lock all operations with element by label + std::unique_lock lock_label(getLabelOpMutex(label)); + + std::unique_lock lock_table(label_lookup_lock); auto search = label_lookup_.find(label); if (search == label_lookup_.end()) { throw std::runtime_error("Label not found"); } tableint internalId = search->second; + lock_table.unlock(); + markDeletedInternal(internalId); } - /** - * Uses the last 16 bits of the memory for the linked list size to store the mark, - * whereas maxM0_ has to be limited to the lower 16 bits, however, still large enough in almost all cases. - */ + /* + * Uses the last 16 bits of the memory for the linked list size to store the mark, + * whereas maxM0_ has to be limited to the lower 16 bits, however, still large enough in almost all cases. + */ void markDeletedInternal(tableint internalId) { assert(internalId < cur_element_count); if (!isMarkedDeleted(internalId)) { unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2; *ll_cur |= DELETE_MARK; num_deleted_ += 1; + if (allow_replace_deleted_) { + std::unique_lock lock_deleted_elements(deleted_elements_lock); + deleted_elements.insert(internalId); + } } else { throw std::runtime_error("The requested to delete element is already deleted"); } } - /** - * Remove the deleted mark of the node, does NOT really change the current graph. - */ + /* + * Removes the deleted mark of the node, does NOT really change the current graph. + * + * Note: the method is not safe to use when replacement of deleted elements is enabled, + * because elements marked as deleted can be completely removed by addPoint + */ void unmarkDelete(labeltype label) { + // lock all operations with element by label + std::unique_lock lock_label(getLabelOpMutex(label)); + + std::unique_lock lock_table(label_lookup_lock); auto search = label_lookup_.find(label); if (search == label_lookup_.end()) { throw std::runtime_error("Label not found"); } tableint internalId = search->second; + lock_table.unlock(); + unmarkDeletedInternal(internalId); } - /** - * Remove the deleted mark of the node. - */ + + /* + * Remove the deleted mark of the node. 
+ */ void unmarkDeletedInternal(tableint internalId) { assert(internalId < cur_element_count); if (isMarkedDeleted(internalId)) { unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId)) + 2; *ll_cur &= ~DELETE_MARK; num_deleted_ -= 1; + if (allow_replace_deleted_) { + std::unique_lock lock_deleted_elements(deleted_elements_lock); + deleted_elements.erase(internalId); + } } else { throw std::runtime_error("The requested to undelete element is not deleted"); } } - /** - * Checks the first 16 bits of the memory to see if the element is marked deleted. - */ + /* + * Checks the first 16 bits of the memory to see if the element is marked deleted. + */ bool isMarkedDeleted(tableint internalId) const { unsigned char *ll_cur = ((unsigned char*)get_linklist0(internalId)) + 2; return *ll_cur & DELETE_MARK; @@ -799,11 +848,48 @@ class HierarchicalNSW : public AlgorithmInterface { } - /** - * Adds point. Updates the point if it is already in the index + /* + * Adds point. Updates the point if it is already in the index. + * If replacement of deleted elements is enabled: replaces previously deleted point if any, updating it with new point */ - void addPoint(const void *data_point, labeltype label) { - addPoint(data_point, label, -1); + void addPoint(const void *data_point, labeltype label, bool replace_deleted = false) { + if ((allow_replace_deleted_ == false) && (replace_deleted == true)) { + throw std::runtime_error("Replacement of deleted elements is disabled in constructor"); + } + + // lock all operations with element by label + std::unique_lock lock_label(getLabelOpMutex(label)); + if (!replace_deleted) { + addPoint(data_point, label, -1); + return; + } + // check if there is vacant place + tableint internal_id_replaced; + std::unique_lock lock_deleted_elements(deleted_elements_lock); + bool is_vacant_place = !deleted_elements.empty(); + if (is_vacant_place) { + internal_id_replaced = *deleted_elements.begin(); + deleted_elements.erase(internal_id_replaced); + } + lock_deleted_elements.unlock(); + + // if there is no vacant place then add or update point + // else add point to vacant place + if (!is_vacant_place) { + addPoint(data_point, label, -1); + } else { + // we assume that there are no concurrent operations on deleted element + labeltype label_replaced = getExternalLabel(internal_id_replaced); + setExternalLabel(internal_id_replaced, label); + + std::unique_lock lock_table(label_lookup_lock); + label_lookup_.erase(label_replaced); + label_lookup_[label] = internal_id_replaced; + lock_table.unlock(); + + unmarkDeletedInternal(internal_id_replaced); + updatePoint(data_point, internal_id_replaced, 1.0); + } } @@ -970,13 +1056,16 @@ class HierarchicalNSW : public AlgorithmInterface { { // Checking if the element with the same label already exists // if so, updating it *instead* of creating a new element. 
- std::unique_lock templock_curr(cur_element_count_guard_); + std::unique_lock lock_table(label_lookup_lock); auto search = label_lookup_.find(label); if (search != label_lookup_.end()) { tableint existingInternalId = search->second; - templock_curr.unlock(); - - std::unique_lock lock_el_update(link_list_update_locks_[(existingInternalId & (max_update_element_locks - 1))]); + if (allow_replace_deleted_) { + if (isMarkedDeleted(existingInternalId)) { + throw std::runtime_error("Can't use addPoint to update deleted elements if replacement of deleted elements is enabled."); + } + } + lock_table.unlock(); if (isMarkedDeleted(existingInternalId)) { unmarkDeletedInternal(existingInternalId); @@ -995,8 +1084,6 @@ class HierarchicalNSW : public AlgorithmInterface { label_lookup_[label] = cur_c; } - // Take update lock to prevent race conditions on an element with insertion/update at the same time. - std::unique_lock lock_el_update(link_list_update_locks_[(cur_c & (max_update_element_locks - 1))]); std::unique_lock lock_el(link_list_locks_[cur_c]); int curlevel = getRandomLevel(mult_); if (level > 0) diff --git a/hnswlib/hnswlib.h b/hnswlib/hnswlib.h index 72c955dc..fb7118fa 100644 --- a/hnswlib/hnswlib.h +++ b/hnswlib/hnswlib.h @@ -158,7 +158,7 @@ class SpaceInterface { template class AlgorithmInterface { public: - virtual void addPoint(const void *datapoint, labeltype label) = 0; + virtual void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) = 0; virtual std::priority_queue> searchKnn(const void*, size_t, BaseFilterFunctor* isIdAllowed = nullptr) const = 0; diff --git a/python_bindings/bindings.cpp b/python_bindings/bindings.cpp index 3da8dbba..3196a228 100644 --- a/python_bindings/bindings.cpp +++ b/python_bindings/bindings.cpp @@ -193,12 +193,13 @@ class Index { size_t maxElements, size_t M, size_t efConstruction, - size_t random_seed) { + size_t random_seed, + bool allow_replace_deleted) { if (appr_alg) { throw std::runtime_error("The index is already initiated."); } cur_l = 0; - appr_alg = new hnswlib::HierarchicalNSW(l2space, maxElements, M, efConstruction, random_seed); + appr_alg = new hnswlib::HierarchicalNSW(l2space, maxElements, M, efConstruction, random_seed, allow_replace_deleted); index_inited = true; ep_added = false; appr_alg->ef_ = default_ef; @@ -223,12 +224,12 @@ class Index { } - void loadIndex(const std::string &path_to_index, size_t max_elements) { + void loadIndex(const std::string &path_to_index, size_t max_elements, bool allow_replace_deleted) { if (appr_alg) { std::cerr << "Warning: Calling load_index for an already inited index. Old index is being deallocated." 
<< std::endl; delete appr_alg; } - appr_alg = new hnswlib::HierarchicalNSW(l2space, path_to_index, false, max_elements); + appr_alg = new hnswlib::HierarchicalNSW(l2space, path_to_index, false, max_elements, allow_replace_deleted); cur_l = appr_alg->cur_element_count; index_inited = true; } @@ -244,7 +245,7 @@ class Index { } - void addItems(py::object input, py::object ids_ = py::none(), int num_threads = -1) { + void addItems(py::object input, py::object ids_ = py::none(), int num_threads = -1, bool replace_deleted = false) { py::array_t < dist_t, py::array::c_style | py::array::forcecast > items(input); auto buffer = items.request(); if (num_threads <= 0) @@ -273,7 +274,7 @@ class Index { normalize_vector(vector_data, norm_array.data()); vector_data = norm_array.data(); } - appr_alg->addPoint((void*)vector_data, (size_t)id); + appr_alg->addPoint((void*)vector_data, (size_t)id, replace_deleted); start = 1; ep_added = true; } @@ -282,7 +283,7 @@ class Index { if (normalize == false) { ParallelFor(start, rows, num_threads, [&](size_t row, size_t threadId) { size_t id = ids.size() ? ids.at(row) : (cur_l + row); - appr_alg->addPoint((void*)items.data(row), (size_t)id); + appr_alg->addPoint((void*)items.data(row), (size_t)id, replace_deleted); }); } else { std::vector norm_array(num_threads * dim); @@ -292,7 +293,7 @@ class Index { normalize_vector((float*)items.data(row), (norm_array.data() + start_idx)); size_t id = ids.size() ? ids.at(row) : (cur_l + row); - appr_alg->addPoint((void*)(norm_array.data() + start_idx), (size_t)id); + appr_alg->addPoint((void*)(norm_array.data() + start_idx), (size_t)id, replace_deleted); }); } cur_l += rows; @@ -400,7 +401,7 @@ class Index { return py::dict( "offset_level0"_a = appr_alg->offsetLevel0_, "max_elements"_a = appr_alg->max_elements_, - "cur_element_count"_a = appr_alg->cur_element_count, + "cur_element_count"_a = (size_t)appr_alg->cur_element_count, "size_data_per_element"_a = appr_alg->size_data_per_element_, "label_offset"_a = appr_alg->label_offset_, "offset_data"_a = appr_alg->offsetData_, @@ -414,6 +415,7 @@ class Index { "ef"_a = appr_alg->ef_, "has_deletions"_a = (bool)appr_alg->num_deleted_, "size_links_per_element"_a = appr_alg->size_links_per_element_, + "allow_replace_deleted"_a = appr_alg->allow_replace_deleted_, "label_lookup_external"_a = py::array_t( { appr_alg->label_lookup_.size() }, // shape @@ -576,12 +578,19 @@ class Index { } // process deleted elements + bool allow_replace_deleted = false; + if (d.contains("allow_replace_deleted")) { + allow_replace_deleted = d["allow_replace_deleted"].cast(); + } + appr_alg->allow_replace_deleted_= allow_replace_deleted; + appr_alg->num_deleted_ = 0; bool has_deletions = d["has_deletions"].cast(); if (has_deletions) { for (size_t i = 0; i < appr_alg->cur_element_count; i++) { if (appr_alg->isMarkedDeleted(i)) { appr_alg->num_deleted_ += 1; + if (allow_replace_deleted) appr_alg->deleted_elements.insert(i); } } } @@ -871,15 +880,35 @@ PYBIND11_PLUGIN(hnswlib) { /* WARNING: Index::createFromIndex is not thread-safe with Index::addItems */ .def(py::init(&Index::createFromIndex), py::arg("index")) .def(py::init(), py::arg("space"), py::arg("dim")) - .def("init_index", &Index::init_new_index, py::arg("max_elements"), py::arg("M") = 16, py::arg("ef_construction") = 200, py::arg("random_seed") = 100) - .def("knn_query", &Index::knnQuery_return_numpy, py::arg("data"), py::arg("k") = 1, py::arg("num_threads") = -1, py::arg("filter") = py::none()) - .def("add_items", &Index::addItems, py::arg("data"), 
py::arg("ids") = py::none(), py::arg("num_threads") = -1) + .def("init_index", + &Index::init_new_index, + py::arg("max_elements"), + py::arg("M") = 16, + py::arg("ef_construction") = 200, + py::arg("random_seed") = 100, + py::arg("allow_replace_deleted") = false) + .def("knn_query", + &Index::knnQuery_return_numpy, + py::arg("data"), + py::arg("k") = 1, + py::arg("num_threads") = -1, + py::arg("filter") = py::none()) + .def("add_items", + &Index::addItems, + py::arg("data"), + py::arg("ids") = py::none(), + py::arg("num_threads") = -1, + py::arg("replace_deleted") = false) .def("get_items", &Index::getDataReturnList, py::arg("ids") = py::none()) .def("get_ids_list", &Index::getIdsList) .def("set_ef", &Index::set_ef, py::arg("ef")) .def("set_num_threads", &Index::set_num_threads, py::arg("num_threads")) .def("save_index", &Index::saveIndex, py::arg("path_to_index")) - .def("load_index", &Index::loadIndex, py::arg("path_to_index"), py::arg("max_elements") = 0) + .def("load_index", + &Index::loadIndex, + py::arg("path_to_index"), + py::arg("max_elements") = 0, + py::arg("allow_replace_deleted") = false) .def("mark_deleted", &Index::markDeleted, py::arg("label")) .def("unmark_deleted", &Index::unmarkDeleted, py::arg("label")) .def("resize_index", &Index::resizeIndex, py::arg("new_size")) @@ -901,7 +930,7 @@ PYBIND11_PLUGIN(hnswlib) { return index.index_inited ? index.appr_alg->max_elements_ : 0; }) .def_property_readonly("element_count", [](const Index & index) { - return index.index_inited ? index.appr_alg->cur_element_count : 0; + return index.index_inited ? (size_t)index.appr_alg->cur_element_count : 0; }) .def_property_readonly("ef_construction", [](const Index & index) { return index.index_inited ? index.appr_alg->ef_construction_ : 0; diff --git a/python_bindings/tests/bindings_test_filter.py b/python_bindings/tests/bindings_test_filter.py index a0715d7c..a798e02f 100644 --- a/python_bindings/tests/bindings_test_filter.py +++ b/python_bindings/tests/bindings_test_filter.py @@ -49,7 +49,7 @@ def testRandomSelf(self): filter_function = lambda id: id%2 == 0 labels, distances = hnsw_index.knn_query(data, k=1, filter=filter_function) self.assertAlmostEqual(np.mean(labels.reshape(-1) == np.arange(len(data))), .5, 3) - # Verify that there are onle even elements: + # Verify that there are only even elements: self.assertTrue(np.max(np.mod(labels, 2)) == 0) labels, distances = bf_index.knn_query(data, k=1, filter=filter_function) diff --git a/python_bindings/tests/bindings_test_labels.py b/python_bindings/tests/bindings_test_labels.py index 2b091371..524a24d5 100644 --- a/python_bindings/tests/bindings_test_labels.py +++ b/python_bindings/tests/bindings_test_labels.py @@ -95,19 +95,20 @@ def testRandomSelf(self): # Delete data1 labels1_deleted, _ = p.knn_query(data1, k=1) - - for l in labels1_deleted: - p.mark_deleted(l[0]) + # delete probable duplicates from nearest neighbors + labels1_deleted_no_dup = set(labels1_deleted.flatten()) + for l in labels1_deleted_no_dup: + p.mark_deleted(l) labels2, _ = p.knn_query(data2, k=1) items = p.get_items(labels2) diff_with_gt_labels = np.mean(np.abs(data2-items)) - self.assertAlmostEqual(diff_with_gt_labels, 0, delta=1e-3) # console + self.assertAlmostEqual(diff_with_gt_labels, 0, delta=1e-3) labels1_after, _ = p.knn_query(data1, k=1) for la in labels1_after: - for lb in labels1_deleted: - if la[0] == lb[0]: - self.assertTrue(False) + if la[0] in labels1_deleted_no_dup: + print(f"Found deleted label {la[0]} during knn search") + self.assertTrue(False) 
print("All the data in data1 are removed") # Checking saving/loading index with elements marked as deleted @@ -119,13 +120,13 @@ def testRandomSelf(self): labels1_after, _ = p.knn_query(data1, k=1) for la in labels1_after: - for lb in labels1_deleted: - if la[0] == lb[0]: - self.assertTrue(False) + if la[0] in labels1_deleted_no_dup: + print(f"Found deleted label {la[0]} during knn search after index loading") + self.assertTrue(False) # Unmark deleted data - for l in labels1_deleted: - p.unmark_deleted(l[0]) + for l in labels1_deleted_no_dup: + p.unmark_deleted(l) labels_restored, _ = p.knn_query(data1, k=1) self.assertAlmostEqual(np.mean(labels_restored.reshape(-1) == np.arange(len(data1))), 1.0, 3) print("All the data in data1 are restored") diff --git a/python_bindings/tests/bindings_test_recall.py b/python_bindings/tests/bindings_test_recall.py index 55a970d1..2190ba45 100644 --- a/python_bindings/tests/bindings_test_recall.py +++ b/python_bindings/tests/bindings_test_recall.py @@ -40,7 +40,7 @@ def testRandomSelf(self): # Set number of threads used during batch search/construction in hnsw # By default using all available cores - hnsw_index.set_num_threads(1) + hnsw_index.set_num_threads(4) print("Adding batch of %d elements" % (len(data))) hnsw_index.add_items(data) diff --git a/python_bindings/tests/bindings_test_replace.py b/python_bindings/tests/bindings_test_replace.py new file mode 100644 index 00000000..80003a3a --- /dev/null +++ b/python_bindings/tests/bindings_test_replace.py @@ -0,0 +1,245 @@ +import os +import pickle +import unittest + +import numpy as np + +import hnswlib + + +class RandomSelfTestCase(unittest.TestCase): + def testRandomSelf(self): + """ + Tests if replace of deleted elements works correctly + Tests serialization of the index with replaced elements + """ + dim = 16 + num_elements = 5000 + max_num_elements = 2 * num_elements + + recall_threshold = 0.98 + + # Generating sample data + print("Generating data") + # batch 1 + first_id = 0 + last_id = num_elements + labels1 = np.arange(first_id, last_id) + data1 = np.float32(np.random.random((num_elements, dim))) + # batch 2 + first_id += num_elements + last_id += num_elements + labels2 = np.arange(first_id, last_id) + data2 = np.float32(np.random.random((num_elements, dim))) + # batch 3 + first_id += num_elements + last_id += num_elements + labels3 = np.arange(first_id, last_id) + data3 = np.float32(np.random.random((num_elements, dim))) + # batch 4 + first_id += num_elements + last_id += num_elements + labels4 = np.arange(first_id, last_id) + data4 = np.float32(np.random.random((num_elements, dim))) + + # Declaring index + hnsw_index = hnswlib.Index(space='l2', dim=dim) + hnsw_index.init_index(max_elements=max_num_elements, ef_construction=200, M=16, allow_replace_deleted=True) + + hnsw_index.set_ef(100) + hnsw_index.set_num_threads(4) + + # Add batch 1 and 2 + print("Adding batch 1") + hnsw_index.add_items(data1, labels1) + print("Adding batch 2") + hnsw_index.add_items(data2, labels2) # maximum number of elements is reached + + # Delete nearest neighbors of batch 2 + print("Deleting neighbors of batch 2") + labels2_deleted, _ = hnsw_index.knn_query(data2, k=1) + # delete probable duplicates from nearest neighbors + labels2_deleted_no_dup = set(labels2_deleted.flatten()) + num_duplicates = len(labels2_deleted) - len(labels2_deleted_no_dup) + for l in labels2_deleted_no_dup: + hnsw_index.mark_deleted(l) + labels1_found, _ = hnsw_index.knn_query(data1, k=1) + items = hnsw_index.get_items(labels1_found) + 
diff_with_gt_labels = np.mean(np.abs(data1 - items)) + self.assertAlmostEqual(diff_with_gt_labels, 0, delta=1e-3) + + labels2_after, _ = hnsw_index.knn_query(data2, k=1) + for la in labels2_after: + if la[0] in labels2_deleted_no_dup: + print(f"Found deleted label {la[0]} during knn search") + self.assertTrue(False) + print("All the neighbors of data2 are removed") + + # Replace deleted elements + print("Inserting batch 3 by replacing deleted elements") + # Maximum number of elements is reached therefore we cannot add new items + # but we can replace the deleted ones + # Note: there may be less than num_elements elements. + # As we could delete less than num_elements because of duplicates + labels3_tr = labels3[0:labels3.shape[0] - num_duplicates] + data3_tr = data3[0:data3.shape[0] - num_duplicates] + hnsw_index.add_items(data3_tr, labels3_tr, replace_deleted=True) + + # After replacing, all labels should be retrievable + print("Checking that remaining labels are in index") + # Get remaining data from batch 1 and batch 2 after deletion of elements + remaining_labels = (set(labels1) | set(labels2)) - labels2_deleted_no_dup + remaining_labels_list = list(remaining_labels) + comb_data = np.concatenate((data1, data2), axis=0) + remaining_data = comb_data[remaining_labels_list] + + returned_items = hnsw_index.get_items(remaining_labels_list) + self.assertSequenceEqual(remaining_data.tolist(), returned_items) + + returned_items = hnsw_index.get_items(labels3_tr) + self.assertSequenceEqual(data3_tr.tolist(), returned_items) + + # Check index serialization + # Delete batch 3 + print("Deleting batch 3") + for l in labels3_tr: + hnsw_index.mark_deleted(l) + + # Save index + index_path = "index.bin" + print(f"Saving index to {index_path}") + hnsw_index.save_index(index_path) + del hnsw_index + + # Reinit and load the index + hnsw_index = hnswlib.Index(space='l2', dim=dim) # the space can be changed - keeps the data, alters the distance function. 
+ hnsw_index.set_num_threads(4) + print(f"Loading index from {index_path}") + hnsw_index.load_index(index_path, max_elements=max_num_elements, allow_replace_deleted=True) + + # Insert batch 4 + print("Inserting batch 4 by replacing deleted elements") + labels4_tr = labels4[0:labels4.shape[0] - num_duplicates] + data4_tr = data4[0:data4.shape[0] - num_duplicates] + hnsw_index.add_items(data4_tr, labels4_tr, replace_deleted=True) + + # Check recall + print("Checking recall") + labels_found, _ = hnsw_index.knn_query(data4_tr, k=1) + recall = np.mean(labels_found.reshape(-1) == labels4_tr) + print(f"Recall for the 4 batch: {recall}") + self.assertGreater(recall, recall_threshold) + + # Delete batch 4 + print("Deleting batch 4") + for l in labels4_tr: + hnsw_index.mark_deleted(l) + + print("Testing pickle serialization") + hnsw_index_pckl = pickle.loads(pickle.dumps(hnsw_index)) + del hnsw_index + # Insert batch 3 + print("Inserting batch 3 by replacing deleted elements") + hnsw_index_pckl.add_items(data3_tr, labels3_tr, replace_deleted=True) + + # Check recall + print("Checking recall") + labels_found, _ = hnsw_index_pckl.knn_query(data3_tr, k=1) + recall = np.mean(labels_found.reshape(-1) == labels3_tr) + print(f"Recall for the 3 batch: {recall}") + self.assertGreater(recall, recall_threshold) + + os.remove(index_path) + + + def test_recall_degradation(self): + """ + Compares recall of the index with replaced elements and without + Measures recall degradation + """ + dim = 16 + num_elements = 10_000 + max_num_elements = 2 * num_elements + query_size = 1_000 + k = 100 + + recall_threshold = 0.98 + max_recall_diff = 0.02 + + # Generating sample data + print("Generating data") + # batch 1 + first_id = 0 + last_id = num_elements + labels1 = np.arange(first_id, last_id) + data1 = np.float32(np.random.random((num_elements, dim))) + # batch 2 + first_id += num_elements + last_id += num_elements + labels2 = np.arange(first_id, last_id) + data2 = np.float32(np.random.random((num_elements, dim))) + # batch 3 + first_id += num_elements + last_id += num_elements + labels3 = np.arange(first_id, last_id) + data3 = np.float32(np.random.random((num_elements, dim))) + # query to test recall + query_data = np.float32(np.random.random((query_size, dim))) + + # Declaring index + hnsw_index_no_replace = hnswlib.Index(space='l2', dim=dim) + hnsw_index_no_replace.init_index(max_elements=max_num_elements, ef_construction=200, M=16, allow_replace_deleted=False) + hnsw_index_with_replace = hnswlib.Index(space='l2', dim=dim) + hnsw_index_with_replace.init_index(max_elements=max_num_elements, ef_construction=200, M=16, allow_replace_deleted=True) + + bf_index = hnswlib.BFIndex(space='l2', dim=dim) + bf_index.init_index(max_elements=max_num_elements) + + hnsw_index_no_replace.set_ef(100) + hnsw_index_no_replace.set_num_threads(50) + hnsw_index_with_replace.set_ef(100) + hnsw_index_with_replace.set_num_threads(50) + + # Add data + print("Adding data") + hnsw_index_with_replace.add_items(data1, labels1) + hnsw_index_with_replace.add_items(data2, labels2) # maximum number of elements is reached + bf_index.add_items(data1, labels1) + bf_index.add_items(data3, labels3) # maximum number of elements is reached + + for l in labels2: + hnsw_index_with_replace.mark_deleted(l) + hnsw_index_with_replace.add_items(data3, labels3, replace_deleted=True) + + hnsw_index_no_replace.add_items(data1, labels1) + hnsw_index_no_replace.add_items(data3, labels3) # maximum number of elements is reached + + # Query the elements and measure 
recall:
+        labels_hnsw_with_replace, _ = hnsw_index_with_replace.knn_query(query_data, k)
+        labels_hnsw_no_replace, _ = hnsw_index_no_replace.knn_query(query_data, k)
+        labels_bf, distances_bf = bf_index.knn_query(query_data, k)
+
+        # Measure recall
+        correct_with_replace = 0
+        correct_no_replace = 0
+        for i in range(query_size):
+            for label in labels_hnsw_with_replace[i]:
+                for correct_label in labels_bf[i]:
+                    if label == correct_label:
+                        correct_with_replace += 1
+                        break
+            for label in labels_hnsw_no_replace[i]:
+                for correct_label in labels_bf[i]:
+                    if label == correct_label:
+                        correct_no_replace += 1
+                        break
+
+        recall_with_replace = float(correct_with_replace) / (k*query_size)
+        recall_no_replace = float(correct_no_replace) / (k*query_size)
+        print("recall with replace:", recall_with_replace)
+        print("recall without replace:", recall_no_replace)
+
+        recall_diff = abs(recall_with_replace - recall_no_replace)
+
+        self.assertGreater(recall_no_replace, recall_threshold)
+        self.assertLess(recall_diff, max_recall_diff)
diff --git a/python_bindings/tests/bindings_test_stress_mt_replace.py b/python_bindings/tests/bindings_test_stress_mt_replace.py
new file mode 100644
index 00000000..8cd3e9bc
--- /dev/null
+++ b/python_bindings/tests/bindings_test_stress_mt_replace.py
@@ -0,0 +1,68 @@
+import unittest
+
+import numpy as np
+
+import hnswlib
+
+
+class RandomSelfTestCase(unittest.TestCase):
+    def testRandomSelf(self):
+        dim = 16
+        num_elements = 1_000
+        max_num_elements = 2 * num_elements
+
+        # Generating sample data
+        # batch 1
+        first_id = 0
+        last_id = num_elements
+        labels1 = np.arange(first_id, last_id)
+        data1 = np.float32(np.random.random((num_elements, dim)))
+        # batch 2
+        first_id += num_elements
+        last_id += num_elements
+        labels2 = np.arange(first_id, last_id)
+        data2 = np.float32(np.random.random((num_elements, dim)))
+        # batch 3
+        first_id += num_elements
+        last_id += num_elements
+        labels3 = np.arange(first_id, last_id)
+        data3 = np.float32(np.random.random((num_elements, dim)))
+
+        for _ in range(100):
+            # Declaring index
+            hnsw_index = hnswlib.Index(space='l2', dim=dim)
+            hnsw_index.init_index(max_elements=max_num_elements, ef_construction=200, M=16, allow_replace_deleted=True)
+
+            hnsw_index.set_ef(100)
+            hnsw_index.set_num_threads(50)
+
+            # Add batch 1 and 2
+            hnsw_index.add_items(data1, labels1)
+            hnsw_index.add_items(data2, labels2)  # maximum number of elements is reached
+
+            # Delete nearest neighbors of batch 2
+            labels2_deleted, _ = hnsw_index.knn_query(data2, k=1)
+            labels2_deleted_flat = labels2_deleted.flatten()
+            # delete probable duplicates from nearest neighbors
+            labels2_deleted_no_dup = set(labels2_deleted_flat)
+            for l in labels2_deleted_no_dup:
+                hnsw_index.mark_deleted(l)
+            labels1_found, _ = hnsw_index.knn_query(data1, k=1)
+            items = hnsw_index.get_items(labels1_found)
+            diff_with_gt_labels = np.mean(np.abs(data1 - items))
+            self.assertAlmostEqual(diff_with_gt_labels, 0, delta=1e-3)
+
+            labels2_after, _ = hnsw_index.knn_query(data2, k=1)
+            labels2_after_flat = labels2_after.flatten()
+            common = np.intersect1d(labels2_after_flat, labels2_deleted_flat)
+            self.assertTrue(common.size == 0)
+
+            # Replace deleted elements
+            # Maximum number of elements is reached, therefore we cannot add new items,
+            # but we can replace the deleted ones
+            # Note: there may be fewer than num_elements elements.
+            # As we could have deleted fewer than num_elements because of duplicates
+            num_duplicates = len(labels2_deleted) - len(labels2_deleted_no_dup)
+            labels3_tr = labels3[0:labels3.shape[0] - num_duplicates]
+            data3_tr = data3[0:data3.shape[0] - num_duplicates]
+            hnsw_index.add_items(data3_tr, labels3_tr, replace_deleted=True)
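As a complement to the stress tests above, a small sanity-check sketch (not part of the patch; assuming the Python bindings introduced here, with illustrative sizes) showing that `replace_deleted=True` reuses the slots freed by `mark_deleted` instead of growing the index:

```python
import hnswlib
import numpy as np

dim = 16
num_elements = 500

data1 = np.float32(np.random.random((num_elements, dim)))
data2 = np.float32(np.random.random((num_elements // 2, dim)))

index = hnswlib.Index(space='l2', dim=dim)
index.init_index(max_elements=num_elements, ef_construction=100, M=16, allow_replace_deleted=True)
index.add_items(data1, np.arange(num_elements))
assert index.element_count == num_elements  # capacity is fully used

# Mark half of the elements as deleted, then replace them with new data under new labels
for label in range(num_elements // 2):
    index.mark_deleted(label)
index.add_items(data2, np.arange(num_elements, num_elements + num_elements // 2), replace_deleted=True)

# Replacement reuses the freed slots, so the number of stored elements does not grow
assert index.element_count == num_elements
```

Because replacement rewrites an existing slot through an update rather than appending a new element, `element_count` stays at the index capacity and no `resize_index` call is needed.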