Skip to content

Commit

Permalink
fix memory leak for hnsw (#137)
Browse files Browse the repository at this point in the history
Signed-off-by: jinjiabao.jjb <jinjiabao.jjb@antgroup.com>
  • Loading branch information
inabao authored and jinjiabao.jjb committed Nov 25, 2024
1 parent d619c01 commit e900f25
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 31 deletions.
16 changes: 16 additions & 0 deletions include/vsag/allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,22 @@ class Allocator {
virtual void*
Reallocate(void* p, size_t size) = 0;

template <typename T, typename... Args>
T*
New(Args&&... args) {
void* p = Allocate(sizeof(T));
return (T*)::new (p) T(std::forward<Args>(args)...);
}

template <typename T>
void
Delete(T* p) {
if (p) {
p->~T();
Deallocate(static_cast<void*>(p));
}
}

public:
virtual ~Allocator() = default;
};
Expand Down
13 changes: 9 additions & 4 deletions src/algorithm/hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class HierarchicalNSW : public AlgorithmInterface<float> {
double mult_{0.0}, revSize_{0.0};
int maxlevel_{0};

std::shared_ptr<VisitedListPool> visited_list_pool_{nullptr};
VisitedListPool* visited_list_pool_{nullptr};

// Locks operations with element by label value
mutable vsag::Vector<std::mutex> label_op_locks_;
Expand Down Expand Up @@ -169,8 +169,6 @@ class HierarchicalNSW : public AlgorithmInterface<float> {

cur_element_count_ = 0;

visited_list_pool_ = std::make_shared<VisitedListPool>(1, max_elements, allocator_);

// initializations for special treatment of the first node
enterpoint_node_ = -1;
maxlevel_ = -1;
Expand Down Expand Up @@ -890,7 +888,9 @@ class HierarchicalNSW : public AlgorithmInterface<float> {
throw std::runtime_error(
"Cannot Resize, max element is less than the current number of elements");

visited_list_pool_.reset(new VisitedListPool(1, new_max_elements, allocator_));
auto new_visited_list_pool = allocator_->New<VisitedListPool>(new_max_elements, allocator_);
allocator_->Delete(visited_list_pool_);
visited_list_pool_ = new_visited_list_pool;

auto element_levels_new =
(int*)allocator_->Reallocate(element_levels_, new_max_elements * sizeof(int));
Expand Down Expand Up @@ -1769,6 +1769,10 @@ class HierarchicalNSW : public AlgorithmInterface<float> {

void
reset() {
if (visited_list_pool_) {
allocator_->Delete(visited_list_pool_);
visited_list_pool_ = nullptr;
}
allocator_->Deallocate(element_levels_);
element_levels_ = nullptr;
allocator_->Deallocate(reversed_level0_link_list_);
Expand All @@ -1784,6 +1788,7 @@ class HierarchicalNSW : public AlgorithmInterface<float> {
bool
init_memory_space() override {
reset();
visited_list_pool_ = allocator_->New<VisitedListPool>(max_elements_, allocator_);
element_levels_ = (int*)allocator_->Allocate(max_elements_ * sizeof(int));
if (not data_level0_memory_->Resize(max_elements_)) {
throw std::runtime_error("allocate data_level0_memory_ error");
Expand Down
11 changes: 5 additions & 6 deletions src/algorithm/hnswlib/hnswalg_static.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ class StaticHierarchicalNSW : public AlgorithmInterface<float> {

cur_element_count_ = 0;

visited_list_pool_ = new VisitedListPool(1, max_elements, allocator_);
visited_list_pool_ = allocator_->New<VisitedListPool>(max_elements, allocator_);

// initializations for special treatment of the first node
enterpoint_node_ = -1;
Expand All @@ -187,7 +187,7 @@ class StaticHierarchicalNSW : public AlgorithmInterface<float> {
}
allocator_->Deallocate(element_levels_);
allocator_->Deallocate(linkLists_);
delete visited_list_pool_;
allocator_->Delete(visited_list_pool_);
CodeBook().swap(pq_book);
allocator_->Deallocate(pq_map);
allocator_->Deallocate(node_cluster_dist_);
Expand Down Expand Up @@ -935,8 +935,8 @@ class StaticHierarchicalNSW : public AlgorithmInterface<float> {
throw std::runtime_error(
"Cannot Resize, max element is less than the current number of elements");

delete visited_list_pool_;
visited_list_pool_ = new VisitedListPool(1, new_max_elements, allocator_);
allocator_->Delete(visited_list_pool_);
visited_list_pool_ = allocator_->New<VisitedListPool>(new_max_elements, allocator_);

element_levels_ =
(int*)allocator_->Reallocate(element_levels_, new_max_elements * sizeof(int));
Expand Down Expand Up @@ -1511,8 +1511,7 @@ class StaticHierarchicalNSW : public AlgorithmInterface<float> {
std::vector<std::mutex>(max_elements).swap(link_list_locks_);
std::vector<std::mutex>(MAX_LABEL_OPERATION_LOCKS).swap(label_op_locks_);

delete visited_list_pool_;
visited_list_pool_ = new VisitedListPool(1, max_elements, allocator_);
visited_list_pool_ = allocator_->New<VisitedListPool>(max_elements, allocator_);

free(linkLists_);
linkLists_ = (char**)malloc(sizeof(void*) * max_elements);
Expand Down
41 changes: 20 additions & 21 deletions src/algorithm/hnswlib/visited_list_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
#include <cstring>
#include <deque>
#include <functional>
#include <iostream>
#include <mutex>

#include "../../default_allocator.h"
#include "stream_writer.h"

namespace vsag {

extern void*
Expand All @@ -41,18 +41,20 @@ typedef unsigned short int vl_type;

class VisitedList {
public:
vl_type curV;
vl_type* mass;
unsigned int numelements;
vl_type curV{0};
vl_type* mass{nullptr};
uint64_t numelements{0};

VisitedList(int numelements1, vsag::Allocator* allocator) : allocator_(allocator) {
VisitedList(uint64_t numelements1, vsag::Allocator* allocator) : allocator_(allocator) {
curV = -1;
numelements = numelements1;
mass = (vl_type*)allocator_->Allocate(numelements * sizeof(vl_type));
}

void
reset() {
if (not mass) {
mass = (vl_type*)allocator_->Allocate(numelements * sizeof(vl_type));
}
curV++;
if (curV == 0) {
memset(mass, 0, sizeof(vl_type) * numelements);
Expand All @@ -75,23 +77,20 @@ class VisitedList {

class VisitedListPool {
public:
VisitedListPool(int initmaxpools, int numelements1, vsag::Allocator* allocator)
: allocator_(allocator) {
numelements = numelements1;
for (int i = 0; i < initmaxpools; i++)
pool.push_front(std::make_shared<VisitedList>(numelements, allocator_));
VisitedListPool(uint64_t max_element_count, vsag::Allocator* allocator)
: allocator_(allocator), pool_(allocator), max_element_count_(max_element_count) {
}

std::shared_ptr<VisitedList>
getFreeVisitedList() {
std::shared_ptr<VisitedList> rez;
{
std::unique_lock<std::mutex> lock(poolguard);
if (pool.size() > 0) {
rez = pool.front();
pool.pop_front();
std::unique_lock<std::mutex> lock(poolguard_);
if (not pool_.empty()) {
rez = pool_.front();
pool_.pop_back();
} else {
rez = std::make_shared<VisitedList>(numelements, allocator_);
rez = std::make_shared<VisitedList>(max_element_count_, allocator_);
}
}
rez->reset();
Expand All @@ -100,14 +99,14 @@ class VisitedListPool {

void
releaseVisitedList(std::shared_ptr<VisitedList> vl) {
std::unique_lock<std::mutex> lock(poolguard);
pool.push_front(vl);
std::unique_lock<std::mutex> lock(poolguard_);
pool_.push_back(vl);
}

private:
std::deque<std::shared_ptr<VisitedList>> pool;
std::mutex poolguard;
uint64_t numelements;
vsag::Vector<std::shared_ptr<VisitedList>> pool_;
std::mutex poolguard_;
uint64_t max_element_count_;
vsag::Allocator* allocator_;
};

Expand Down

0 comments on commit e900f25

Please sign in to comment.