Skip to content

Commit

Permalink
use dispatcher function to call HammingComputer (facebookresearch#2918)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookresearch#2918

The HammingComputer class is optimized for several vector sizes. So far it's been the caller's responsiblity to instanciate the relevant optimized version.

This diff introduces a `dispatch_HammingComputer` function that can be called with a template class that is instanciated for all existing optimized HammingComputer's.

Reviewed By: algoriddle

Differential Revision: D46858553

fbshipit-source-id: 32c31689bba7c0b406b309fc8574c95fa24022ba
  • Loading branch information
mdouze authored and facebook-github-bot committed Jun 26, 2023
1 parent a27036a commit a91a288
Show file tree
Hide file tree
Showing 13 changed files with 365 additions and 686 deletions.
60 changes: 60 additions & 0 deletions benchs/bench_hamming_computer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,66 @@

using namespace faiss;

// These implementations are currently slower than HammingComputerDefault so
// they are not in the main faiss anymore.
struct HammingComputerM8 {
const uint64_t* a;
int n;

HammingComputerM8() {}

HammingComputerM8(const uint8_t* a8, int code_size) {
set(a8, code_size);
}

void set(const uint8_t* a8, int code_size) {
assert(code_size % 8 == 0);
a = (uint64_t*)a8;
n = code_size / 8;
}

int hamming(const uint8_t* b8) const {
const uint64_t* b = (uint64_t*)b8;
int accu = 0;
for (int i = 0; i < n; i++)
accu += popcount64(a[i] ^ b[i]);
return accu;
}

inline int get_code_size() const {
return n * 8;
}
};

struct HammingComputerM4 {
const uint32_t* a;
int n;

HammingComputerM4() {}

HammingComputerM4(const uint8_t* a4, int code_size) {
set(a4, code_size);
}

void set(const uint8_t* a4, int code_size) {
assert(code_size % 4 == 0);
a = (uint32_t*)a4;
n = code_size / 4;
}

int hamming(const uint8_t* b8) const {
const uint32_t* b = (uint32_t*)b8;
int accu = 0;
for (int i = 0; i < n; i++)
accu += popcount64(a[i] ^ b[i]);
return accu;
}

inline int get_code_size() const {
return n * 4;
}
};

template <class T>
void hamming_cpt_test(
int code_size,
Expand Down
30 changes: 10 additions & 20 deletions faiss/IndexBinaryHNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,31 +281,21 @@ struct FlatHammingDis : DistanceComputer {
}
};

struct BuildDistanceComputer {
using T = DistanceComputer*;
template <class HammingComputer>
DistanceComputer* f(IndexBinaryFlat* flat_storage) {
return new FlatHammingDis<HammingComputer>(*flat_storage);
}
};

} // namespace

DistanceComputer* IndexBinaryHNSW::get_distance_computer() const {
IndexBinaryFlat* flat_storage = dynamic_cast<IndexBinaryFlat*>(storage);

FAISS_ASSERT(flat_storage != nullptr);

switch (code_size) {
case 4:
return new FlatHammingDis<HammingComputer4>(*flat_storage);
case 8:
return new FlatHammingDis<HammingComputer8>(*flat_storage);
case 16:
return new FlatHammingDis<HammingComputer16>(*flat_storage);
case 20:
return new FlatHammingDis<HammingComputer20>(*flat_storage);
case 32:
return new FlatHammingDis<HammingComputer32>(*flat_storage);
case 64:
return new FlatHammingDis<HammingComputer64>(*flat_storage);
default:
break;
}

return new FlatHammingDis<HammingComputerDefault>(*flat_storage);
BuildDistanceComputer bd;
return dispatch_HammingComputer(code_size, bd, flat_storage);
}

} // namespace faiss
74 changes: 25 additions & 49 deletions faiss/IndexBinaryHash.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,14 @@ void search_single_query_template(
} while (fe.next());
}

struct Run_search_single_query {
using T = void;
template <class HammingComputer, class... Types>
T f(Types... args) {
search_single_query_template<HammingComputer>(args...);
}
};

template <class SearchResults>
void search_single_query(
const IndexBinaryHash& index,
Expand All @@ -184,29 +192,9 @@ void search_single_query(
size_t& n0,
size_t& nlist,
size_t& ndis) {
#define HC(name) \
search_single_query_template<name>(index, q, res, n0, nlist, ndis);
switch (index.code_size) {
case 4:
HC(HammingComputer4);
break;
case 8:
HC(HammingComputer8);
break;
case 16:
HC(HammingComputer16);
break;
case 20:
HC(HammingComputer20);
break;
case 32:
HC(HammingComputer32);
break;
default:
HC(HammingComputerDefault);
break;
}
#undef HC
Run_search_single_query r;
dispatch_HammingComputer(
index.code_size, r, index, q, res, n0, nlist, ndis);
}

} // anonymous namespace
Expand Down Expand Up @@ -349,22 +337,30 @@ namespace {

template <class HammingComputer, class SearchResults>
static void verify_shortlist(
const IndexBinaryFlat& index,
const IndexBinaryFlat* index,
const uint8_t* q,
const std::unordered_set<idx_t>& shortlist,
SearchResults& res) {
size_t code_size = index.code_size;
size_t code_size = index->code_size;
size_t nlist = 0, ndis = 0, n0 = 0;

HammingComputer hc(q, code_size);
const uint8_t* codes = index.xb.data();
const uint8_t* codes = index->xb.data();

for (auto i : shortlist) {
int dis = hc.hamming(codes + i * code_size);
res.add(dis, i);
}
}

struct Run_verify_shortlist {
using T = void;
template <class HammingComputer, class... Types>
void f(Types... args) {
verify_shortlist<HammingComputer>(args...);
}
};

template <class SearchResults>
void search_1_query_multihash(
const IndexBinaryMultiHash& index,
Expand Down Expand Up @@ -405,29 +401,9 @@ void search_1_query_multihash(
ndis += shortlist.size();

// verify shortlist

#define HC(name) verify_shortlist<name>(*index.storage, xi, shortlist, res)
switch (index.code_size) {
case 4:
HC(HammingComputer4);
break;
case 8:
HC(HammingComputer8);
break;
case 16:
HC(HammingComputer16);
break;
case 20:
HC(HammingComputer20);
break;
case 32:
HC(HammingComputer32);
break;
default:
HC(HammingComputerDefault);
break;
}
#undef HC
Run_verify_shortlist r;
dispatch_HammingComputer(
index.code_size, r, index.storage, xi, shortlist, res);
}

} // anonymous namespace
Expand Down
Loading

0 comments on commit a91a288

Please sign in to comment.