Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use unique_ptr to manage visited_list_pool_ #474

Merged
merged 2 commits into from
Jul 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions hnswlib/bruteforce.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,16 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> {


void removePoint(labeltype cur_external) {
size_t cur_c = dict_external_to_internal[cur_external];
std::unique_lock<std::mutex> lock(index_lock);

dict_external_to_internal.erase(cur_external);
auto found = dict_external_to_internal.find(cur_external);
if (found == dict_external_to_internal.end()) {
return;
}

dict_external_to_internal.erase(found);

size_t cur_c = found->second;
labeltype label = *((labeltype*)(data_ + size_per_element_ * (cur_element_count-1) + data_size_));
dict_external_to_internal[label] = cur_c;
memcpy(data_ + size_per_element_ * cur_c,
Expand Down
11 changes: 5 additions & 6 deletions hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <assert.h>
#include <unordered_set>
#include <list>
#include <memory>

namespace hnswlib {
typedef unsigned int tableint;
Expand All @@ -33,7 +34,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
double mult_{0.0}, revSize_{0.0};
int maxlevel_{0};

VisitedListPool *visited_list_pool_{nullptr};
std::unique_ptr<VisitedListPool> visited_list_pool_{nullptr};

// Locks operations with element by label value
mutable std::vector<std::mutex> label_op_locks_;
Expand Down Expand Up @@ -122,7 +123,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {

cur_element_count = 0;

visited_list_pool_ = new VisitedListPool(1, max_elements);
visited_list_pool_ = std::unique_ptr<VisitedListPool>(new VisitedListPool(1, max_elements));

// initializations for special treatment of the first node
enterpoint_node_ = -1;
Expand All @@ -144,7 +145,6 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
free(linkLists_[i]);
}
free(linkLists_);
delete visited_list_pool_;
}


Expand Down Expand Up @@ -573,8 +573,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
if (new_max_elements < cur_element_count)
throw std::runtime_error("Cannot resize, max element is less than the current number of elements");

delete visited_list_pool_;
visited_list_pool_ = new VisitedListPool(1, new_max_elements);
visited_list_pool_.reset(new VisitedListPool(1, new_max_elements));

element_levels_.resize(new_max_elements);

Expand Down Expand Up @@ -724,7 +723,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
std::vector<std::mutex>(max_elements).swap(link_list_locks_);
std::vector<std::mutex>(MAX_LABEL_OPERATION_LOCKS).swap(label_op_locks_);

visited_list_pool_ = new VisitedListPool(1, max_elements);
visited_list_pool_.reset(new VisitedListPool(1, max_elements));

linkLists_ = (char **) malloc(sizeof(void *) * max_elements);
if (linkLists_ == nullptr)
Expand Down