diff --git a/include/vsag/allocator.h b/include/vsag/allocator.h index d612325f..449184c8 100644 --- a/include/vsag/allocator.h +++ b/include/vsag/allocator.h @@ -36,6 +36,22 @@ class Allocator { virtual void* Reallocate(void* p, size_t size) = 0; + template + T* + New(Args&&... args) { + void* p = Allocate(sizeof(T)); + return (T*)::new (p) T(std::forward(args)...); + } + + template + void + Delete(T* p) { + if (p) { + p->~T(); + Deallocate(static_cast(p)); + } + } + public: virtual ~Allocator() = default; }; diff --git a/src/algorithm/hnswlib/hnswalg.h b/src/algorithm/hnswlib/hnswalg.h index ce4b742a..da1a1b7d 100644 --- a/src/algorithm/hnswlib/hnswalg.h +++ b/src/algorithm/hnswlib/hnswalg.h @@ -74,7 +74,7 @@ class HierarchicalNSW : public AlgorithmInterface { double mult_{0.0}, revSize_{0.0}; int maxlevel_{0}; - std::shared_ptr visited_list_pool_{nullptr}; + VisitedListPool* visited_list_pool_{nullptr}; // Locks operations with element by label value mutable vsag::Vector label_op_locks_; @@ -169,8 +169,6 @@ class HierarchicalNSW : public AlgorithmInterface { cur_element_count_ = 0; - visited_list_pool_ = std::make_shared(1, max_elements, allocator_); - // initializations for special treatment of the first node enterpoint_node_ = -1; maxlevel_ = -1; @@ -890,7 +888,9 @@ class HierarchicalNSW : public AlgorithmInterface { throw std::runtime_error( "Cannot Resize, max element is less than the current number of elements"); - visited_list_pool_.reset(new VisitedListPool(1, new_max_elements, allocator_)); + auto new_visited_list_pool = allocator_->New(new_max_elements, allocator_); + allocator_->Delete(visited_list_pool_); + visited_list_pool_ = new_visited_list_pool; auto element_levels_new = (int*)allocator_->Reallocate(element_levels_, new_max_elements * sizeof(int)); @@ -1769,6 +1769,10 @@ class HierarchicalNSW : public AlgorithmInterface { void reset() { + if (visited_list_pool_) { + allocator_->Delete(visited_list_pool_); + visited_list_pool_ = nullptr; + } allocator_->Deallocate(element_levels_); element_levels_ = nullptr; allocator_->Deallocate(reversed_level0_link_list_); @@ -1784,6 +1788,7 @@ class HierarchicalNSW : public AlgorithmInterface { bool init_memory_space() override { reset(); + visited_list_pool_ = allocator_->New(max_elements_, allocator_); element_levels_ = (int*)allocator_->Allocate(max_elements_ * sizeof(int)); if (not data_level0_memory_->Resize(max_elements_)) { throw std::runtime_error("allocate data_level0_memory_ error"); diff --git a/src/algorithm/hnswlib/hnswalg_static.h b/src/algorithm/hnswlib/hnswalg_static.h index 85748e62..c8d89d1c 100644 --- a/src/algorithm/hnswlib/hnswalg_static.h +++ b/src/algorithm/hnswlib/hnswalg_static.h @@ -161,7 +161,7 @@ class StaticHierarchicalNSW : public AlgorithmInterface { cur_element_count_ = 0; - visited_list_pool_ = new VisitedListPool(1, max_elements, allocator_); + visited_list_pool_ = allocator_->New(max_elements, allocator_); // initializations for special treatment of the first node enterpoint_node_ = -1; @@ -187,7 +187,7 @@ class StaticHierarchicalNSW : public AlgorithmInterface { } allocator_->Deallocate(element_levels_); allocator_->Deallocate(linkLists_); - delete visited_list_pool_; + allocator_->Delete(visited_list_pool_); CodeBook().swap(pq_book); allocator_->Deallocate(pq_map); allocator_->Deallocate(node_cluster_dist_); @@ -935,8 +935,8 @@ class StaticHierarchicalNSW : public AlgorithmInterface { throw std::runtime_error( "Cannot Resize, max element is less than the current number of elements"); - delete visited_list_pool_; - visited_list_pool_ = new VisitedListPool(1, new_max_elements, allocator_); + allocator_->Delete(visited_list_pool_); + visited_list_pool_ = allocator_->New(new_max_elements, allocator_); element_levels_ = (int*)allocator_->Reallocate(element_levels_, new_max_elements * sizeof(int)); @@ -1511,8 +1511,7 @@ class StaticHierarchicalNSW : public AlgorithmInterface { std::vector(max_elements).swap(link_list_locks_); std::vector(MAX_LABEL_OPERATION_LOCKS).swap(label_op_locks_); - delete visited_list_pool_; - visited_list_pool_ = new VisitedListPool(1, max_elements, allocator_); + visited_list_pool_ = allocator_->New(max_elements, allocator_); free(linkLists_); linkLists_ = (char**)malloc(sizeof(void*) * max_elements); diff --git a/src/algorithm/hnswlib/visited_list_pool.h b/src/algorithm/hnswlib/visited_list_pool.h index f8d66c78..c0990b31 100644 --- a/src/algorithm/hnswlib/visited_list_pool.h +++ b/src/algorithm/hnswlib/visited_list_pool.h @@ -18,11 +18,11 @@ #include #include #include +#include #include #include "../../default_allocator.h" #include "stream_writer.h" - namespace vsag { extern void* @@ -41,18 +41,20 @@ typedef unsigned short int vl_type; class VisitedList { public: - vl_type curV; - vl_type* mass; - unsigned int numelements; + vl_type curV{0}; + vl_type* mass{nullptr}; + uint64_t numelements{0}; - VisitedList(int numelements1, vsag::Allocator* allocator) : allocator_(allocator) { + VisitedList(uint64_t numelements1, vsag::Allocator* allocator) : allocator_(allocator) { curV = -1; numelements = numelements1; - mass = (vl_type*)allocator_->Allocate(numelements * sizeof(vl_type)); } void reset() { + if (not mass) { + mass = (vl_type*)allocator_->Allocate(numelements * sizeof(vl_type)); + } curV++; if (curV == 0) { memset(mass, 0, sizeof(vl_type) * numelements); @@ -75,23 +77,20 @@ class VisitedList { class VisitedListPool { public: - VisitedListPool(int initmaxpools, int numelements1, vsag::Allocator* allocator) - : allocator_(allocator) { - numelements = numelements1; - for (int i = 0; i < initmaxpools; i++) - pool.push_front(std::make_shared(numelements, allocator_)); + VisitedListPool(uint64_t max_element_count, vsag::Allocator* allocator) + : allocator_(allocator), pool_(allocator), max_element_count_(max_element_count) { } std::shared_ptr getFreeVisitedList() { std::shared_ptr rez; { - std::unique_lock lock(poolguard); - if (pool.size() > 0) { - rez = pool.front(); - pool.pop_front(); + std::unique_lock lock(poolguard_); + if (not pool_.empty()) { + rez = pool_.front(); + pool_.pop_back(); } else { - rez = std::make_shared(numelements, allocator_); + rez = std::make_shared(max_element_count_, allocator_); } } rez->reset(); @@ -100,14 +99,14 @@ class VisitedListPool { void releaseVisitedList(std::shared_ptr vl) { - std::unique_lock lock(poolguard); - pool.push_front(vl); + std::unique_lock lock(poolguard_); + pool_.push_back(vl); } private: - std::deque> pool; - std::mutex poolguard; - uint64_t numelements; + vsag::Vector> pool_; + std::mutex poolguard_; + uint64_t max_element_count_; vsag::Allocator* allocator_; };