Skip to content

Commit

Permalink
CodeSet for deduping large datasets
Browse files Browse the repository at this point in the history
Differential Revision: D47443953

fbshipit-source-id: 6701f985d56638cd0ff33217d5069cf8a86f9fe6
  • Loading branch information
algoriddle authored and facebook-github-bot committed Jul 13, 2023
1 parent 43d86e3 commit 73563ff
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 0 deletions.
1 change: 1 addition & 0 deletions faiss/python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
class_wrappers.handle_IDSelectorSubset(IDSelectorBatch, class_owns=True)
class_wrappers.handle_IDSelectorSubset(IDSelectorArray, class_owns=False)
class_wrappers.handle_IDSelectorSubset(IDSelectorBitmap, class_owns=False, force_int64=False)
class_wrappers.handle_CodeSet(CodeSet)

this_module = sys.modules[__name__]

Expand Down
18 changes: 18 additions & 0 deletions faiss/python/class_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,3 +1102,21 @@ def replacement_init(self, *args):
self.original_init(*args)

the_class.__init__ = replacement_init


def handle_CodeSet(the_class):

def replacement_insert(self, codes, inserted=None):
n, d = codes.shape
assert d == self.d
codes = np.ascontiguousarray(codes, dtype=np.uint8)

if inserted is None:
inserted = np.empty(n, dtype=np.bool)
else:
assert inserted.shape == (n, )

self.insert_c(n, swig_ptr(codes), swig_ptr(inserted))
return inserted

replace_method(the_class, 'insert', replacement_insert)
9 changes: 9 additions & 0 deletions faiss/utils/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <omp.h>

#include <algorithm>
#include <set>
#include <type_traits>
#include <vector>

Expand Down Expand Up @@ -623,4 +624,12 @@ void CombinerRangeKNN<T>::write_result(T* D_res, int64_t* I_res) {
template struct CombinerRangeKNN<float>;
template struct CombinerRangeKNN<int16_t>;

void CodeSet::insert(size_t n, const uint8_t* codes, bool* inserted) {
for (size_t i = 0; i < n; i++) {
auto res = s.insert(
std::vector<uint8_t>(codes + i * d, codes + i * d + d));
inserted[i] = res.second;
}
}

} // namespace faiss
10 changes: 10 additions & 0 deletions faiss/utils/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
#define FAISS_utils_h

#include <stdint.h>
#include <set>
#include <string>
#include <vector>

#include <faiss/impl/platform_macros.h>
#include <faiss/utils/Heap.h>
Expand Down Expand Up @@ -209,6 +211,14 @@ struct CombinerRangeKNN {
void write_result(T* D_res, int64_t* I_res);
};

struct CodeSet {
size_t d;
std::set<std::vector<uint8_t>> s;

explicit CodeSet(size_t d) : d(d) {}
void insert(size_t n, const uint8_t* codes, bool* inserted);
};

} // namespace faiss

#endif /* FAISS_utils_h */
14 changes: 14 additions & 0 deletions tests/test_contrib.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,3 +630,17 @@ def test_hnsw_permute(self):
np.testing.assert_equal(Dnew, Dref)
Inew_remap = perm[Inew]
np.testing.assert_equal(Inew_remap, Iref)


class TestCodeSet(unittest.TestCase):

def test_code_set(self):
""" CodeSet and np.unique should produce the same output """
d = 8
n = 1000 # > 256 and using only 0 or 1 so there must be duplicates
codes = np.random.randint(0, 2, (n, d), dtype=np.uint8)
s = faiss.CodeSet(d)
inserted = s.insert(codes)
np.testing.assert_equal(
np.sort(np.unique(codes, axis=0), axis=None),
np.sort(codes[inserted], axis=None))

0 comments on commit 73563ff

Please sign in to comment.