Skip to content

Commit

Permalink
make nbits configurable for graph indices based on PQ
Browse files Browse the repository at this point in the history
Summary:
As requested in

#3027

Indeed, PQ sizes with nbits > 8 are good tradeoffs, so it is interesting to support them.

Differential Revision: D48860659

fbshipit-source-id: 4956fdf5a442dae0a206478e211a72108c85c284
  • Loading branch information
mdouze authored and facebook-github-bot committed Aug 31, 2023
1 parent 3888f9b commit 0fed833
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 51 deletions.
16 changes: 4 additions & 12 deletions faiss/IndexHNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,18 +250,10 @@ void hnsw_add_vertices(
**************************************************************/

IndexHNSW::IndexHNSW(int d, int M, MetricType metric)
: Index(d, metric),
hnsw(M),
own_fields(false),
storage(nullptr),
reconstruct_from_neighbors(nullptr) {}
: Index(d, metric), hnsw(M) {}

IndexHNSW::IndexHNSW(Index* storage, int M)
: Index(storage->d, storage->metric_type),
hnsw(M),
own_fields(false),
storage(storage),
reconstruct_from_neighbors(nullptr) {}
: Index(storage->d, storage->metric_type), hnsw(M), storage(storage) {}

IndexHNSW::~IndexHNSW() {
if (own_fields) {
Expand Down Expand Up @@ -886,8 +878,8 @@ IndexHNSWFlat::IndexHNSWFlat(int d, int M, MetricType metric)

IndexHNSWPQ::IndexHNSWPQ() {}

IndexHNSWPQ::IndexHNSWPQ(int d, int pq_m, int M)
: IndexHNSW(new IndexPQ(d, pq_m, 8), M) {
IndexHNSWPQ::IndexHNSWPQ(int d, int pq_m, int M, int pq_nbits)
: IndexHNSW(new IndexPQ(d, pq_m, pq_nbits), M) {
own_fields = true;
is_trained = false;
}
Expand Down
8 changes: 4 additions & 4 deletions faiss/IndexHNSW.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,10 @@ struct IndexHNSW : Index {
HNSW hnsw;

// the sequential storage
bool own_fields;
Index* storage;
bool own_fields = false;
Index* storage = nullptr;

ReconstructFromNeighbors* reconstruct_from_neighbors;
ReconstructFromNeighbors* reconstruct_from_neighbors = nullptr;

explicit IndexHNSW(int d = 0, int M = 32, MetricType metric = METRIC_L2);
explicit IndexHNSW(Index* storage, int M = 32);
Expand Down Expand Up @@ -152,7 +152,7 @@ struct IndexHNSWFlat : IndexHNSW {
*/
struct IndexHNSWPQ : IndexHNSW {
IndexHNSWPQ();
IndexHNSWPQ(int d, int pq_m, int M);
IndexHNSWPQ(int d, int pq_m, int M, int pq_nbits = 8);
void train(idx_t n, const float* x) override;
};

Expand Down
22 changes: 3 additions & 19 deletions faiss/IndexNSG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,32 +29,16 @@ using namespace nsg;
* IndexNSG implementation
**************************************************************/

IndexNSG::IndexNSG(int d, int R, MetricType metric)
: Index(d, metric),
nsg(R),
own_fields(false),
storage(nullptr),
is_built(false),
GK(64),
build_type(0) {
nndescent_S = 10;
nndescent_R = 100;
IndexNSG::IndexNSG(int d, int R, MetricType metric) : Index(d, metric), nsg(R) {
nndescent_L = GK + 50;
nndescent_iter = 10;
}

IndexNSG::IndexNSG(Index* storage, int R)
: Index(storage->d, storage->metric_type),
nsg(R),
own_fields(false),
storage(storage),
is_built(false),
GK(64),
build_type(1) {
nndescent_S = 10;
nndescent_R = 100;
nndescent_L = GK + 50;
nndescent_iter = 10;
}

IndexNSG::~IndexNSG() {
Expand Down Expand Up @@ -304,8 +288,8 @@ IndexNSGFlat::IndexNSGFlat(int d, int R, MetricType metric)

IndexNSGPQ::IndexNSGPQ() {}

IndexNSGPQ::IndexNSGPQ(int d, int pq_m, int M)
: IndexNSG(new IndexPQ(d, pq_m, 8), M) {
IndexNSGPQ::IndexNSGPQ(int d, int pq_m, int M, int pq_nbits)
: IndexNSG(new IndexPQ(d, pq_m, pq_nbits), M) {
own_fields = true;
is_trained = false;
}
Expand Down
20 changes: 10 additions & 10 deletions faiss/IndexNSG.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,25 +28,25 @@ struct IndexNSG : Index {
NSG nsg;

/// the sequential storage
bool own_fields;
Index* storage;
bool own_fields = false;
Index* storage = nullptr;

/// the index is built or not
bool is_built;
bool is_built = false;

/// K of KNN graph for building
int GK;
int GK = 64;

/// indicate how to build a knn graph
/// - 0: build NSG with brute force search
/// - 1: build NSG with NNDescent
char build_type;
char build_type = 0;

/// parameters for nndescent
int nndescent_S;
int nndescent_R;
int nndescent_L;
int nndescent_iter;
int nndescent_S = 10;
int nndescent_R = 100;
int nndescent_L; // set to GK + 50
int nndescent_iter = 10;

explicit IndexNSG(int d = 0, int R = 32, MetricType metric = METRIC_L2);
explicit IndexNSG(Index* storage, int R = 32);
Expand Down Expand Up @@ -90,7 +90,7 @@ struct IndexNSGFlat : IndexNSG {
*/
struct IndexNSGPQ : IndexNSG {
IndexNSGPQ();
IndexNSGPQ(int d, int pq_m, int M);
IndexNSGPQ(int d, int pq_m, int M, int pq_nbits = 8);
void train(idx_t n, const float* x) override;
};

Expand Down
15 changes: 9 additions & 6 deletions faiss/index_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -440,11 +440,13 @@ IndexHNSW* parse_IndexHNSW(
if (match("Flat|")) {
return new IndexHNSWFlat(d, hnsw_M, mt);
}
if (match("PQ([0-9]+)(np)?")) {

if (match("PQ([0-9]+)(x[0-9]+)?(np)?")) {
int M = std::stoi(sm[1].str());
IndexHNSWPQ* ipq = new IndexHNSWPQ(d, M, hnsw_M);
int nbit = mres_to_int(sm[2], 8, 1);
IndexHNSWPQ* ipq = new IndexHNSWPQ(d, M, hnsw_M, nbit);
dynamic_cast<IndexPQ*>(ipq->storage)->do_polysemous_training =
sm[2].str() != "np";
sm[3].str() != "np";
return ipq;
}
if (match(sq_pattern)) {
Expand Down Expand Up @@ -490,11 +492,12 @@ IndexNSG* parse_IndexNSG(
if (match("Flat|")) {
return new IndexNSGFlat(d, nsg_R, mt);
}
if (match("PQ([0-9]+)(np)?")) {
if (match("PQ([0-9]+)(x[0-9]+)?(np)?")) {
int M = std::stoi(sm[1].str());
IndexNSGPQ* ipq = new IndexNSGPQ(d, M, nsg_R);
int nbit = mres_to_int(sm[2], 8, 1);
IndexNSGPQ* ipq = new IndexNSGPQ(d, M, nsg_R, nbit);
dynamic_cast<IndexPQ*>(ipq->storage)->do_polysemous_training =
sm[2].str() != "np";
sm[3].str() != "np";
return ipq;
}
if (match(sq_pattern)) {
Expand Down
9 changes: 9 additions & 0 deletions tests/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ def test_factory_HNSW_newstyle(self):
index = faiss.index_factory(12, "HNSW32,PQ4np")
indexpq = faiss.downcast_index(index.storage)
assert not indexpq.do_polysemous_training
index = faiss.index_factory(12, "HNSW32,PQ4x12np")
indexpq = faiss.downcast_index(index.storage)
self.assertEqual(indexpq.pq.nbits, 12)

def test_factory_NSG(self):
index = faiss.index_factory(12, "NSG64")
Expand All @@ -97,6 +100,12 @@ def test_factory_NSG(self):
assert isinstance(index, faiss.IndexNSGFlat)
assert index.nsg.R == 64

index = faiss.index_factory(12, "NSG64,PQ3x10")
assert isinstance(index, faiss.IndexNSGPQ)
assert index.nsg.R == 64
indexpq = faiss.downcast_index(index.storage)
self.assertEqual(indexpq.pq.nbits, 10)

index = faiss.index_factory(12, "IVF65536_NSG64,Flat")
index_nsg = faiss.downcast_index(index.quantizer)
assert isinstance(index, faiss.IndexIVFFlat)
Expand Down

0 comments on commit 0fed833

Please sign in to comment.