Binary cloning and GPU range search (facebookresearch#2916)
Summary:
Pull Request resolved: facebookresearch#2916

Overall better support for binary indexes:
- cloning (to CPU and GPU), only for BinaryFlat for now
- fix bug in reconstruct_n
- range_search_max_results
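
A minimal Python sketch of the cloning path described above (dimension, data and GPU count are illustrative; faiss.index_cpu_to_all_gpus is assumed to accept an IndexBinaryFlat here, which is what the contrib change below relies on):

import numpy as np
import faiss

d_bits = 256                                   # 256-bit codes = 32 bytes per vector
xb = np.random.randint(256, size=(100000, d_bits // 8), dtype='uint8')
xq = np.random.randint(256, size=(100, d_bits // 8), dtype='uint8')

index = faiss.IndexBinaryFlat(d_bits)
index.add(xb)

# clone the binary flat index to the available GPUs (requires at least one GPU)
index_gpu = faiss.index_cpu_to_all_gpus(index)

D, I = index_gpu.search(xq, 10)                # Hamming distances as integers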

Reviewed By: algoriddle

Differential Revision: D46755778

fbshipit-source-id: 777ad90aff5c54a77f9685ed6512247a922c6ef5
mdouze authored and Thejas-bhat committed Sep 26, 2023
1 parent 3e0a6d6 commit a90d51e
Showing 22 changed files with 384 additions and 116 deletions.
39 changes: 29 additions & 10 deletions contrib/exhaustive_search.py
@@ -60,12 +60,18 @@ def range_search_gpu(xq, r2, index_gpu, index_cpu, gpu_k=1024):
- None. In that case, at most gpu_k results will be returned
"""
nq, d = xq.shape
k = min(index_gpu.ntotal, gpu_k)
is_binary_index = isinstance(index_gpu, faiss.IndexBinary)
keep_max = faiss.is_similarity_metric(index_gpu.metric_type)
LOG.debug(f"GPU search {nq} queries with {k=:}")
r2 = int(r2) if is_binary_index else float(r2)
k = min(index_gpu.ntotal, gpu_k)
LOG.debug(
f"GPU search {nq} queries with {k=:} {is_binary_index=:} {keep_max=:}")
t0 = time.time()
D, I = index_gpu.search(xq, k)
t1 = time.time() - t0
if is_binary_index:
assert d * 8 < 32768 # let's compact the distance matrix
D = D.astype('int16')
t2 = 0
lim_remain = None
if index_cpu is not None:
@@ -79,14 +85,24 @@ def range_search_gpu(xq, r2, index_gpu, index_cpu, gpu_k=1024):
if isinstance(index_cpu, np.ndarray):
# then it is in fact an array that we have to make flat
xb = index_cpu
index_cpu = faiss.IndexFlat(d, index_gpu.metric_type)
if is_binary_index:
index_cpu = faiss.IndexBinaryFlat(d * 8)
else:
index_cpu = faiss.IndexFlat(d, index_gpu.metric_type)
index_cpu.add(xb)
lim_remain, D_remain, I_remain = index_cpu.range_search(xq[mask], r2)
if is_binary_index:
D_remain = D_remain.astype('int16')
t2 = time.time() - t0
LOG.debug("combine")
t0 = time.time()

combiner = faiss.CombinerRangeKNN(nq, k, float(r2), keep_max)
CombinerRangeKNN = (
faiss.CombinerRangeKNNint16 if is_binary_index else
faiss.CombinerRangeKNNfloat
)

combiner = CombinerRangeKNN(nq, k, r2, keep_max)
if True:
sp = faiss.swig_ptr
combiner.I = sp(I)
@@ -101,7 +117,7 @@ def range_search_gpu(xq, r2, index_gpu, index_cpu, gpu_k=1024):
L_res = np.empty(nq + 1, dtype='int64')
combiner.compute_sizes(sp(L_res))
nres = L_res[-1]
D_res = np.empty(nres, dtype='float32')
D_res = np.empty(nres, dtype=D.dtype)
I_res = np.empty(nres, dtype='int64')
combiner.write_result(sp(D_res), sp(I_res))
else:
@@ -251,6 +267,7 @@ def range_search_max_results(index, query_iterator, radius,
"""
# TODO: all result manipulations are in python, should move to C++ if perf
# critical
is_binary_index = isinstance(index, faiss.IndexBinary)

if min_results is None:
assert max_results is not None
@@ -268,6 +285,8 @@ def range_search_max_results(index, query_iterator, radius,
co = faiss.GpuMultipleClonerOptions()
co.shard = shard
index_gpu = faiss.index_cpu_to_all_gpus(index, co=co, ngpu=ngpu)
else:
index_gpu = None

t_start = time.time()
t_search = t_post_process = 0
@@ -276,7 +295,8 @@ def range_search_max_results(index, query_iterator, radius,

for xqi in query_iterator:
t0 = time.time()
if ngpu > 0:
LOG.debug(f"searching {len(xqi)} vectors")
if index_gpu:
lims_i, Di, Ii = range_search_gpu(xqi, radius, index_gpu, index)
else:
lims_i, Di, Ii = index.range_search(xqi, radius)
@@ -286,8 +306,7 @@ def range_search_max_results(index, query_iterator, radius,
qtot += len(xqi)

t1 = time.time()
if xqi.dtype != np.float32:
# for binary indexes
if is_binary_index:
# weird Faiss quirk that returns floats for Hamming distances
Di = Di.astype('int16')

@@ -299,7 +318,7 @@ def range_search_max_results(index, query_iterator, radius,
(totres, max_results))
radius, totres = apply_maxres(
res_batches, min_results,
keep_max=faiss.is_similarity_metric(index.metric_type)
keep_max=index.metric_type == faiss.METRIC_INNER_PRODUCT
)
t2 = time.time()
t_search += t1 - t0
@@ -315,7 +334,7 @@ def range_search_max_results(index, query_iterator, radius,
if clip_to_min and totres > min_results:
radius, totres = apply_maxres(
res_batches, min_results,
keep_max=faiss.is_similarity_metric(index.metric_type)
keep_max=index.metric_type == faiss.METRIC_INNER_PRODUCT
)

nres = np.hstack([nres_i for nres_i, dis_i, ids_i in res_batches])
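
A hedged usage sketch of the patched helper with a binary index (batch size, radius and result caps are illustrative; the 4-tuple return value follows the contrib API):

import numpy as np
import faiss
from faiss.contrib.exhaustive_search import range_search_max_results

d_bits = 128
xb = np.random.randint(256, size=(100000, d_bits // 8), dtype='uint8')
xq = np.random.randint(256, size=(1000, d_bits // 8), dtype='uint8')

index = faiss.IndexBinaryFlat(d_bits)
index.add(xb)

def query_iterator(bs=256):
    # the helper consumes query batches one at a time
    for i0 in range(0, len(xq), bs):
        yield xq[i0:i0 + bs]

# for a binary index the radius is an integer Hamming threshold; when the number
# of results exceeds max_results, the radius is tightened towards min_results
radius, lims, D, I = range_search_max_results(
    index, query_iterator(), 16,
    max_results=10**6, min_results=10**5, ngpu=0)
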
11 changes: 8 additions & 3 deletions faiss/IndexBinary.cpp
@@ -15,6 +15,11 @@

namespace faiss {

IndexBinary::IndexBinary(idx_t d, MetricType metric)
: d(d), code_size(d / 8), metric_type(metric) {
FAISS_THROW_IF_NOT(d % 8 == 0);
}

IndexBinary::~IndexBinary() {}

void IndexBinary::train(idx_t, const uint8_t*) {
@@ -51,7 +56,7 @@ void IndexBinary::reconstruct(idx_t, uint8_t*) const {

void IndexBinary::reconstruct_n(idx_t i0, idx_t ni, uint8_t* recons) const {
for (idx_t i = 0; i < ni; i++) {
reconstruct(i0 + i, recons + i * d);
reconstruct(i0 + i, recons + i * code_size);
}
}

@@ -70,10 +75,10 @@ void IndexBinary::search_and_reconstruct(
for (idx_t j = 0; j < k; ++j) {
idx_t ij = i * k + j;
idx_t key = labels[ij];
uint8_t* reconstructed = recons + ij * d;
uint8_t* reconstructed = recons + ij * code_size;
if (key < 0) {
// Fill with NaNs
memset(reconstructed, -1, sizeof(*reconstructed) * d);
memset(reconstructed, -1, code_size);
} else {
reconstruct(key, reconstructed);
}
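
The reconstruct_n fix matters because a binary code occupies code_size = d / 8 bytes, not d bytes, so the old stride overran a buffer sized for ni * code_size. A small sanity check of the intended behaviour, assuming the pythonic reconstruct wrapper for binary indexes:

import numpy as np
import faiss

d_bits = 64                                    # code_size = 8 bytes per vector
xb = np.random.randint(256, size=(10, d_bits // 8), dtype='uint8')

index = faiss.IndexBinaryFlat(d_bits)
index.add(xb)

# each reconstructed code is code_size bytes; stacking them reproduces the input
rec = np.vstack([index.reconstruct(i) for i in range(index.ntotal)])
assert rec.shape == (10, d_bits // 8)
assert np.array_equal(rec, xb)
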
27 changes: 8 additions & 19 deletions faiss/IndexBinary.h
@@ -5,8 +5,6 @@
* LICENSE file in the root directory of this source tree.
*/

// -*- c++ -*-

#ifndef FAISS_INDEX_BINARY_H
#define FAISS_INDEX_BINARY_H

@@ -16,7 +14,6 @@
#include <typeinfo>

#include <faiss/Index.h>
#include <faiss/impl/FaissAssert.h>

namespace faiss {

@@ -35,27 +32,19 @@ struct IndexBinary {
using component_t = uint8_t;
using distance_t = int32_t;

int d; ///< vector dimension
int code_size; ///< number of bytes per vector ( = d / 8 )
idx_t ntotal; ///< total nb of indexed vectors
bool verbose; ///< verbosity level
int d = 0; ///< vector dimension
int code_size = 0; ///< number of bytes per vector ( = d / 8 )
idx_t ntotal = 0; ///< total nb of indexed vectors
bool verbose = false; ///< verbosity level

/// set if the Index does not require training, or if training is done
/// already
bool is_trained;
bool is_trained = true;

/// type of metric this index uses for search
MetricType metric_type;

explicit IndexBinary(idx_t d = 0, MetricType metric = METRIC_L2)
: d(d),
code_size(d / 8),
ntotal(0),
verbose(false),
is_trained(true),
metric_type(metric) {
FAISS_THROW_IF_NOT(d % 8 == 0);
}
MetricType metric_type = METRIC_L2;

explicit IndexBinary(idx_t d = 0, MetricType metric = METRIC_L2);

virtual ~IndexBinary();

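
The constructor now lives in IndexBinary.cpp and keeps the d % 8 == 0 check; from Python the failure mode is the usual RuntimeError raised by FAISS_THROW (a sketch, with an intentionally bad dimension):

import faiss

ok = faiss.IndexBinaryFlat(64)        # fine: 64 bits = 8 bytes per code
try:
    faiss.IndexBinaryFlat(60)         # not a multiple of 8 -> constructor throws
except RuntimeError as e:
    print("rejected:", e)
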
1 change: 1 addition & 0 deletions faiss/IndexBinaryFromFloat.cpp
@@ -9,6 +9,7 @@

#include <faiss/IndexBinaryFromFloat.h>

#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/utils.h>
#include <algorithm>
#include <memory>
17 changes: 15 additions & 2 deletions faiss/IndexIDMap.cpp
@@ -21,18 +21,31 @@

namespace faiss {

namespace {

// IndexBinary needs to update the code_size when d is set...

void sync_d(Index* index) {}

void sync_d(IndexBinary* index) {
FAISS_THROW_IF_NOT(index->d % 8 == 0);
index->code_size = index->d / 8;
}

} // anonymous namespace

/*****************************************************
* IndexIDMap implementation
*******************************************************/

template <typename IndexT>
IndexIDMapTemplate<IndexT>::IndexIDMapTemplate(IndexT* index)
: index(index), own_fields(false) {
IndexIDMapTemplate<IndexT>::IndexIDMapTemplate(IndexT* index) : index(index) {
FAISS_THROW_IF_NOT_MSG(index->ntotal == 0, "index must be empty on input");
this->is_trained = index->is_trained;
this->metric_type = index->metric_type;
this->verbose = index->verbose;
this->d = index->d;
sync_d(this);
}

template <typename IndexT>
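
With sync_d in place, the ID-mapped wrapper keeps code_size consistent when it copies d from the sub-index. A hedged usage sketch for the binary variant (ids and data are illustrative):

import numpy as np
import faiss

d_bits = 64
xb = np.random.randint(256, size=(5, d_bits // 8), dtype='uint8')
ids = np.arange(100, 105).astype('int64')

sub = faiss.IndexBinaryFlat(d_bits)
idmap = faiss.IndexBinaryIDMap(sub)    # copies d from sub and now syncs code_size
idmap.add_with_ids(xb, ids)

D, I = idmap.search(xb, 1)
assert (I[:, 0] == ids).all()          # exact matches come back under the mapped ids
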
4 changes: 2 additions & 2 deletions faiss/IndexIDMap.h
@@ -22,8 +22,8 @@ struct IndexIDMapTemplate : IndexT {
using component_t = typename IndexT::component_t;
using distance_t = typename IndexT::distance_t;

IndexT* index; ///! the sub-index
bool own_fields; ///! whether pointers are deleted in destructor
IndexT* index = nullptr; ///! the sub-index
bool own_fields = false; ///! whether pointers are deleted in destructor
std::vector<idx_t> id_map;

explicit IndexIDMapTemplate(IndexT* index);
45 changes: 21 additions & 24 deletions faiss/IndexReplicas.cpp
@@ -12,17 +12,34 @@

namespace faiss {

namespace {

// IndexBinary needs to update the code_size when d is set...

void sync_d(Index* index) {}

void sync_d(IndexBinary* index) {
FAISS_THROW_IF_NOT(index->d % 8 == 0);
index->code_size = index->d / 8;
}

} // anonymous namespace

template <typename IndexT>
IndexReplicasTemplate<IndexT>::IndexReplicasTemplate(bool threaded)
: ThreadedIndex<IndexT>(threaded) {}

template <typename IndexT>
IndexReplicasTemplate<IndexT>::IndexReplicasTemplate(idx_t d, bool threaded)
: ThreadedIndex<IndexT>(d, threaded) {}
: ThreadedIndex<IndexT>(d, threaded) {
sync_d(this);
}

template <typename IndexT>
IndexReplicasTemplate<IndexT>::IndexReplicasTemplate(int d, bool threaded)
: ThreadedIndex<IndexT>(d, threaded) {}
: ThreadedIndex<IndexT>(d, threaded) {
sync_d(this);
}

template <typename IndexT>
void IndexReplicasTemplate<IndexT>::onAfterAddIndex(IndexT* index) {
@@ -168,6 +185,8 @@ void IndexReplicasTemplate<IndexT>::syncWithSubIndexes() {
}

auto firstIndex = this->at(0);
this->d = firstIndex->d;
sync_d(this);
this->metric_type = firstIndex->metric_type;
this->is_trained = firstIndex->is_trained;
this->ntotal = firstIndex->ntotal;
@@ -181,28 +200,6 @@ void IndexReplicasTemplate<IndexT>::syncWithSubIndexes() {
}
}

// No metric_type for IndexBinary
template <>
void IndexReplicasTemplate<IndexBinary>::syncWithSubIndexes() {
if (!this->count()) {
this->is_trained = false;
this->ntotal = 0;

return;
}

auto firstIndex = this->at(0);
this->is_trained = firstIndex->is_trained;
this->ntotal = firstIndex->ntotal;

for (int i = 1; i < this->count(); ++i) {
auto index = this->at(i);
FAISS_THROW_IF_NOT(this->d == index->d);
FAISS_THROW_IF_NOT(this->is_trained == index->is_trained);
FAISS_THROW_IF_NOT(this->ntotal == index->ntotal);
}
}

// explicit instantiations
template struct IndexReplicasTemplate<Index>;
template struct IndexReplicasTemplate<IndexBinary>;
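
Dropping the IndexBinary specialization works because sync_d recomputes code_size whenever syncWithSubIndexes copies d from the first sub-index. A sketch of the replicas path for binary indexes (two local replicas; the data is illustrative and the SWIG-exposed addIndex method is assumed):

import numpy as np
import faiss

d_bits = 64
xb = np.random.randint(256, size=(1000, d_bits // 8), dtype='uint8')
xq = xb[:10]

replicas = faiss.IndexBinaryReplicas()
subs = []                              # keep Python-side references alive
for _ in range(2):
    sub = faiss.IndexBinaryFlat(d_bits)
    sub.add(xb)
    replicas.addIndex(sub)             # triggers syncWithSubIndexes, incl. d / code_size
    subs.append(sub)

D, I = replicas.search(xq, 1)
assert (I[:, 0] == np.arange(10)).all()  # each query finds itself at distance 0
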