Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python: filter elements with an optional filtering function #417

Merged
merged 4 commits into from
Nov 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 37 additions & 30 deletions examples/searchKnnWithFilter_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,25 @@ namespace {

using idx_t = hnswlib::labeltype;

// Predicate: accepts a label only when it is an exact multiple of three.
bool pickIdsDivisibleByThree(unsigned int label_id) {
    const unsigned int remainder = label_id % 3u;
    return remainder == 0u;
}

// Predicate: accepts a label only when it is an exact multiple of seven.
bool pickIdsDivisibleBySeven(unsigned int label_id) {
    const unsigned int remainder = label_id % 7u;
    return remainder == 0u;
}
// Filter functor that accepts only labels that are exact multiples of a
// divisor chosen at construction time.
class PickDivisibleIds: public hnswlib::BaseFilterFunctor {
    unsigned int divisor = 1;

 public:
    // `divisor` must be non-zero: operator() computes `label_id % divisor`.
    // `explicit` prevents an accidental implicit conversion from an integer.
    explicit PickDivisibleIds(unsigned int divisor): divisor(divisor) {
        assert(divisor != 0);
    }

    // Returns true when `label_id` is evenly divisible by the stored divisor.
    // `override` makes the compiler verify that this matches the base signature.
    bool operator()(idx_t label_id) override {
        return label_id % divisor == 0;
    }
};

// Predicate that rejects every label, whatever its value.
bool pickNothing(unsigned int label_id) {
    static_cast<void>(label_id);  // the argument is intentionally ignored
    return false;
}
// Filter functor that rejects every label.
class PickNothing: public hnswlib::BaseFilterFunctor {
 public:
    // The label is irrelevant, so the parameter is left unnamed to avoid an
    // unused-parameter warning. `override` makes the compiler verify that this
    // matches the base-class signature.
    bool operator()(idx_t /*label_id*/) override {
        return false;
    }
};

template<typename filter_func_t>
void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t label_id_start) {
void test_some_filtering(hnswlib::BaseFilterFunctor& filter_func, size_t div_num, size_t label_id_start) {
int d = 4;
idx_t n = 100;
idx_t nq = 10;
Expand All @@ -45,8 +50,8 @@ void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t labe
}

hnswlib::L2Space space(d);
hnswlib::AlgorithmInterface<float, filter_func_t>* alg_brute = new hnswlib::BruteforceSearch<float, filter_func_t>(&space, 2 * n);
hnswlib::AlgorithmInterface<float, filter_func_t>* alg_hnsw = new hnswlib::HierarchicalNSW<float, filter_func_t>(&space, 2 * n);
hnswlib::AlgorithmInterface<float>* alg_brute = new hnswlib::BruteforceSearch<float>(&space, 2 * n);
hnswlib::AlgorithmInterface<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, 2 * n);

for (size_t i = 0; i < n; ++i) {
// `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs
Expand All @@ -57,8 +62,8 @@ void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t labe
// test searchKnnCloserFirst of BruteforceSearch with filtering
for (size_t j = 0; j < nq; ++j) {
const void* p = query.data() + j * d;
auto gd = alg_brute->searchKnn(p, k, filter_func);
auto res = alg_brute->searchKnnCloserFirst(p, k, filter_func);
auto gd = alg_brute->searchKnn(p, k, &filter_func);
auto res = alg_brute->searchKnnCloserFirst(p, k, &filter_func);
assert(gd.size() == res.size());
size_t t = gd.size();
while (!gd.empty()) {
Expand All @@ -71,8 +76,8 @@ void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t labe
// test searchKnnCloserFirst of hnsw with filtering
for (size_t j = 0; j < nq; ++j) {
const void* p = query.data() + j * d;
auto gd = alg_hnsw->searchKnn(p, k, filter_func);
auto res = alg_hnsw->searchKnnCloserFirst(p, k, filter_func);
auto gd = alg_hnsw->searchKnn(p, k, &filter_func);
auto res = alg_hnsw->searchKnnCloserFirst(p, k, &filter_func);
assert(gd.size() == res.size());
size_t t = gd.size();
while (!gd.empty()) {
Expand All @@ -86,8 +91,7 @@ void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t labe
delete alg_hnsw;
}

template<typename filter_func_t>
void test_none_filtering(filter_func_t& filter_func, size_t label_id_start) {
void test_none_filtering(hnswlib::BaseFilterFunctor& filter_func, size_t label_id_start) {
int d = 4;
idx_t n = 100;
idx_t nq = 10;
Expand All @@ -108,8 +112,8 @@ void test_none_filtering(filter_func_t& filter_func, size_t label_id_start) {
}

hnswlib::L2Space space(d);
hnswlib::AlgorithmInterface<float, filter_func_t>* alg_brute = new hnswlib::BruteforceSearch<float, filter_func_t>(&space, 2 * n);
hnswlib::AlgorithmInterface<float, filter_func_t>* alg_hnsw = new hnswlib::HierarchicalNSW<float, filter_func_t>(&space, 2 * n);
hnswlib::AlgorithmInterface<float>* alg_brute = new hnswlib::BruteforceSearch<float>(&space, 2 * n);
hnswlib::AlgorithmInterface<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, 2 * n);

for (size_t i = 0; i < n; ++i) {
// `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs
Expand All @@ -120,17 +124,17 @@ void test_none_filtering(filter_func_t& filter_func, size_t label_id_start) {
// test searchKnnCloserFirst of BruteforceSearch with filtering
for (size_t j = 0; j < nq; ++j) {
const void* p = query.data() + j * d;
auto gd = alg_brute->searchKnn(p, k, filter_func);
auto res = alg_brute->searchKnnCloserFirst(p, k, filter_func);
auto gd = alg_brute->searchKnn(p, k, &filter_func);
auto res = alg_brute->searchKnnCloserFirst(p, k, &filter_func);
assert(gd.size() == res.size());
assert(0 == gd.size());
}

// test searchKnnCloserFirst of hnsw with filtering
for (size_t j = 0; j < nq; ++j) {
const void* p = query.data() + j * d;
auto gd = alg_hnsw->searchKnn(p, k, filter_func);
auto res = alg_hnsw->searchKnnCloserFirst(p, k, filter_func);
auto gd = alg_hnsw->searchKnn(p, k, &filter_func);
auto res = alg_hnsw->searchKnnCloserFirst(p, k, &filter_func);
assert(gd.size() == res.size());
assert(0 == gd.size());
}
Expand All @@ -141,13 +145,13 @@ void test_none_filtering(filter_func_t& filter_func, size_t label_id_start) {

} // namespace

class CustomFilterFunctor: public hnswlib::FilterFunctor {
std::unordered_set<unsigned int> allowed_values;
class CustomFilterFunctor: public hnswlib::BaseFilterFunctor {
std::unordered_set<idx_t> allowed_values;

public:
explicit CustomFilterFunctor(const std::unordered_set<unsigned int>& values) : allowed_values(values) {}
explicit CustomFilterFunctor(const std::unordered_set<idx_t>& values) : allowed_values(values) {}

bool operator()(unsigned int id) {
bool operator()(idx_t id) {
return allowed_values.count(id) != 0;
}
};
Expand All @@ -156,10 +160,13 @@ int main() {
std::cout << "Testing ..." << std::endl;

// some of the elements are filtered
PickDivisibleIds pickIdsDivisibleByThree(3);
test_some_filtering(pickIdsDivisibleByThree, 3, 17);
PickDivisibleIds pickIdsDivisibleBySeven(7);
test_some_filtering(pickIdsDivisibleBySeven, 7, 17);

// all of the elements are filtered
PickNothing pickNothing;
test_none_filtering(pickNothing, 17);

// functor style which can capture context
Expand Down
11 changes: 5 additions & 6 deletions hnswlib/bruteforce.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
#include <assert.h>

namespace hnswlib {
template<typename dist_t, typename filter_func_t = FilterFunctor>
class BruteforceSearch : public AlgorithmInterface<dist_t, filter_func_t> {
template<typename dist_t>
class BruteforceSearch : public AlgorithmInterface<dist_t> {
public:
char *data_;
size_t maxelements_;
Expand Down Expand Up @@ -98,15 +98,14 @@ class BruteforceSearch : public AlgorithmInterface<dist_t, filter_func_t> {


std::priority_queue<std::pair<dist_t, labeltype >>
searchKnn(const void *query_data, size_t k, filter_func_t& isIdAllowed = allowAllIds) const {
searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const {
assert(k <= cur_element_count);
std::priority_queue<std::pair<dist_t, labeltype >> topResults;
if (cur_element_count == 0) return topResults;
bool is_filter_disabled = std::is_same<filter_func_t, decltype(allowAllIds)>::value;
for (int i = 0; i < k; i++) {
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_));
if (is_filter_disabled || isIdAllowed(label)) {
if ((!isIdAllowed) || (*isIdAllowed)(label)) {
topResults.push(std::pair<dist_t, labeltype>(dist, label));
}
}
Expand All @@ -115,7 +114,7 @@ class BruteforceSearch : public AlgorithmInterface<dist_t, filter_func_t> {
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
if (dist <= lastdist) {
labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_));
if (is_filter_disabled || isIdAllowed(label)) {
if ((!isIdAllowed) || (*isIdAllowed)(label)) {
topResults.push(std::pair<dist_t, labeltype>(dist, label));
}
if (topResults.size() > k)
Expand Down
13 changes: 6 additions & 7 deletions hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ namespace hnswlib {
typedef unsigned int tableint;
typedef unsigned int linklistsizeint;

template<typename dist_t, typename filter_func_t = FilterFunctor>
class HierarchicalNSW : public AlgorithmInterface<dist_t, filter_func_t> {
template<typename dist_t>
class HierarchicalNSW : public AlgorithmInterface<dist_t> {
public:
static const tableint max_update_element_locks = 65536;
static const unsigned char DELETE_MARK = 0x01;
Expand Down Expand Up @@ -268,7 +268,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t, filter_func_t> {

template <bool has_deletions, bool collect_metrics = false>
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, filter_func_t& isIdAllowed) const {
searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, BaseFilterFunctor* isIdAllowed = nullptr) const {
VisitedList *vl = visited_list_pool_->getFreeVisitedList();
vl_type *visited_array = vl->mass;
vl_type visited_array_tag = vl->curV;
Expand All @@ -277,8 +277,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t, filter_func_t> {
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set;

dist_t lowerBound;
bool is_filter_disabled = std::is_same<filter_func_t, decltype(allowAllIds)>::value;
if ((!has_deletions || !isMarkedDeleted(ep_id)) && (is_filter_disabled || isIdAllowed(getExternalLabel(ep_id)))) {
if ((!has_deletions || !isMarkedDeleted(ep_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id)))) {
dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
lowerBound = dist;
top_candidates.emplace(dist, ep_id);
Expand Down Expand Up @@ -336,7 +335,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t, filter_func_t> {
_MM_HINT_T0); ////////////////////////
#endif

if ((!has_deletions || !isMarkedDeleted(candidate_id)) && (is_filter_disabled || isIdAllowed(getExternalLabel(candidate_id))))
if ((!has_deletions || !isMarkedDeleted(candidate_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id))))
top_candidates.emplace(dist, candidate_id);

if (top_candidates.size() > ef)
Expand Down Expand Up @@ -1083,7 +1082,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t, filter_func_t> {


std::priority_queue<std::pair<dist_t, labeltype >>
searchKnn(const void *query_data, size_t k, filter_func_t& isIdAllowed = allowAllIds) const {
searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const {
std::priority_queue<std::pair<dist_t, labeltype >> result;
if (cur_element_count == 0) return result;

Expand Down
20 changes: 9 additions & 11 deletions hnswlib/hnswlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,11 @@ namespace hnswlib {
typedef size_t labeltype;

// This can be extended to store state for filtering (e.g. from a std::set)
struct FilterFunctor {
template<class...Args>
bool operator()(Args&&...) { return true; }
// Polymorphic base class for label filters. Derive from it and override
// operator() to decide which labels a search may return.
class BaseFilterFunctor {
 public:
    // Returns true when the element with label `id` is allowed in the results.
    // The default implementation accepts every label.
    virtual bool operator()(hnswlib::labeltype id) { return true; }

    // Virtual so that a derived filter is destroyed correctly when it is
    // deleted through a BaseFilterFunctor pointer.
    virtual ~BaseFilterFunctor() {}
};

static FilterFunctor allowAllIds;

template <typename T>
class pairGreater {
public:
Expand Down Expand Up @@ -157,27 +155,27 @@ class SpaceInterface {
virtual ~SpaceInterface() {}
};

template<typename dist_t, typename filter_func_t = FilterFunctor>
template<typename dist_t>
class AlgorithmInterface {
public:
virtual void addPoint(const void *datapoint, labeltype label) = 0;

virtual std::priority_queue<std::pair<dist_t, labeltype>>
searchKnn(const void*, size_t, filter_func_t& isIdAllowed = allowAllIds) const = 0;
searchKnn(const void*, size_t, BaseFilterFunctor* isIdAllowed = nullptr) const = 0;

// Return the k nearest neighbors, ordered closest first
virtual std::vector<std::pair<dist_t, labeltype>>
searchKnnCloserFirst(const void* query_data, size_t k, filter_func_t& isIdAllowed = allowAllIds) const;
searchKnnCloserFirst(const void* query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const;

virtual void saveIndex(const std::string &location) = 0;
virtual ~AlgorithmInterface(){
}
};

template<typename dist_t, typename filter_func_t>
template<typename dist_t>
std::vector<std::pair<dist_t, labeltype>>
AlgorithmInterface<dist_t, filter_func_t>::searchKnnCloserFirst(const void* query_data, size_t k,
filter_func_t& isIdAllowed) const {
AlgorithmInterface<dist_t>::searchKnnCloserFirst(const void* query_data, size_t k,
BaseFilterFunctor* isIdAllowed) const {
std::vector<std::pair<dist_t, labeltype>> result;

// here searchKnn returns the result in the order of further first
Expand Down
42 changes: 35 additions & 7 deletions python_bindings/bindings.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <iostream>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
Expand Down Expand Up @@ -79,6 +80,20 @@ inline void assert_true(bool expr, const std::string & msg) {
}


// Adapts a std::function predicate to the hnswlib::BaseFilterFunctor interface
// so that it can be passed to searchKnn as a filter.
class CustomFilterFunctor: public hnswlib::BaseFilterFunctor {
    std::function<bool(hnswlib::labeltype)> filter;

 public:
    // Stores a copy of `f`. The member is initialized directly instead of being
    // default-constructed and then assigned.
    // `f` may be empty; operator() must not be invoked in that case, because
    // calling an empty std::function throws std::bad_function_call.
    explicit CustomFilterFunctor(const std::function<bool(hnswlib::labeltype)>& f) : filter(f) {}

    // Forwards the label to the wrapped predicate and returns its verdict.
    bool operator()(hnswlib::labeltype id) override {
        return filter(id);
    }
};


inline void get_input_array_shapes(const py::buffer_info& buffer, size_t* rows, size_t* features) {
if (buffer.ndim != 2 && buffer.ndim != 1) {
char msg[256];
Expand Down Expand Up @@ -573,7 +588,11 @@ class Index {
}


py::object knnQuery_return_numpy(py::object input, size_t k = 1, int num_threads = -1) {
py::object knnQuery_return_numpy(
py::object input,
size_t k = 1,
int num_threads = -1,
const std::function<bool(hnswlib::labeltype)>& filter = nullptr) {
py::array_t < dist_t, py::array::c_style | py::array::forcecast > items(input);
auto buffer = items.request();
hnswlib::labeltype* data_numpy_l;
Expand All @@ -595,10 +614,13 @@ class Index {
data_numpy_l = new hnswlib::labeltype[rows * k];
data_numpy_d = new dist_t[rows * k];

CustomFilterFunctor idFilter(filter);
CustomFilterFunctor* p_idFilter = filter ? &idFilter : nullptr;

if (normalize == false) {
ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) {
std::priority_queue<std::pair<dist_t, hnswlib::labeltype >> result = appr_alg->searchKnn(
(void*)items.data(row), k);
(void*)items.data(row), k, p_idFilter);
if (result.size() != k)
throw std::runtime_error(
"Cannot return the results in a contigious 2D array. Probably ef or M is too small");
Expand All @@ -618,7 +640,7 @@ class Index {
normalize_vector((float*)items.data(row), (norm_array.data() + start_idx));

std::priority_queue<std::pair<dist_t, hnswlib::labeltype >> result = appr_alg->searchKnn(
(void*)(norm_array.data() + start_idx), k);
(void*)(norm_array.data() + start_idx), k, p_idFilter);
if (result.size() != k)
throw std::runtime_error(
"Cannot return the results in a contigious 2D array. Probably ef or M is too small");
Expand Down Expand Up @@ -785,7 +807,10 @@ class BFIndex {
}


py::object knnQuery_return_numpy(py::object input, size_t k = 1) {
py::object knnQuery_return_numpy(
py::object input,
size_t k = 1,
const std::function<bool(hnswlib::labeltype)>& filter = nullptr) {
py::array_t < dist_t, py::array::c_style | py::array::forcecast > items(input);
auto buffer = items.request();
hnswlib::labeltype *data_numpy_l;
Expand All @@ -799,9 +824,12 @@ class BFIndex {
data_numpy_l = new hnswlib::labeltype[rows * k];
data_numpy_d = new dist_t[rows * k];

CustomFilterFunctor idFilter(filter);
CustomFilterFunctor* p_idFilter = filter ? &idFilter : nullptr;

for (size_t row = 0; row < rows; row++) {
std::priority_queue<std::pair<dist_t, hnswlib::labeltype >> result = alg->searchKnn(
(void *) items.data(row), k);
(void *) items.data(row), k, p_idFilter);
for (int i = k - 1; i >= 0; i--) {
auto &result_tuple = result.top();
data_numpy_d[row * k + i] = result_tuple.first;
Expand Down Expand Up @@ -844,7 +872,7 @@ PYBIND11_PLUGIN(hnswlib) {
.def(py::init(&Index<float>::createFromIndex), py::arg("index"))
.def(py::init<const std::string &, const int>(), py::arg("space"), py::arg("dim"))
.def("init_index", &Index<float>::init_new_index, py::arg("max_elements"), py::arg("M") = 16, py::arg("ef_construction") = 200, py::arg("random_seed") = 100)
.def("knn_query", &Index<float>::knnQuery_return_numpy, py::arg("data"), py::arg("k") = 1, py::arg("num_threads") = -1)
.def("knn_query", &Index<float>::knnQuery_return_numpy, py::arg("data"), py::arg("k") = 1, py::arg("num_threads") = -1, py::arg("filter") = py::none())
.def("add_items", &Index<float>::addItems, py::arg("data"), py::arg("ids") = py::none(), py::arg("num_threads") = -1)
.def("get_items", &Index<float, float>::getDataReturnList, py::arg("ids") = py::none())
.def("get_ids_list", &Index<float>::getIdsList)
Expand Down Expand Up @@ -899,7 +927,7 @@ PYBIND11_PLUGIN(hnswlib) {
py::class_<BFIndex<float>>(m, "BFIndex")
.def(py::init<const std::string &, const int>(), py::arg("space"), py::arg("dim"))
.def("init_index", &BFIndex<float>::init_new_index, py::arg("max_elements"))
.def("knn_query", &BFIndex<float>::knnQuery_return_numpy, py::arg("data"), py::arg("k") = 1)
.def("knn_query", &BFIndex<float>::knnQuery_return_numpy, py::arg("data"), py::arg("k") = 1, py::arg("filter") = py::none())
.def("add_items", &BFIndex<float>::addItems, py::arg("data"), py::arg("ids") = py::none())
.def("delete_vector", &BFIndex<float>::deleteVector, py::arg("label"))
.def("save_index", &BFIndex<float>::saveIndex, py::arg("path_to_index"))
Expand Down
Loading