Skip to content

Commit

Permalink
Allow k and M suffixes in IVF indexes (#3812)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3812

Allows factory strings like `IVF3k,Flat` as a shorthand for 3072 centroids.

The main question is whether k or M should be metric (k=1000) or power of 2 (k=1024):

* pro-metric: standard,

* pro-power of 2: in practice we use powers of 2 most often

The suffixes ki and Mi should be used for powers of 2 but this makes the notation more heavy (which is what we wanted to avoid in the first place).

So I picked power of 2.

Reviewed By: mnorris11

Differential Revision: D62019941

fbshipit-source-id: f547962625123ecdfaa406067781c77386017793
  • Loading branch information
mdouze authored and facebook-github-bot committed Sep 10, 2024
1 parent 6fe4640 commit d85fda7
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 8 deletions.
29 changes: 21 additions & 8 deletions faiss/index_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,19 @@ VectorTransform* parse_VectorTransform(const std::string& description, int d) {
* Parse IndexIVF
*/

size_t parse_nlist(std::string s) {
size_t multiplier = 1;
if (s.back() == 'k') {
s.pop_back();
multiplier = 1024;
}
if (s.back() == 'M') {
s.pop_back();
multiplier = 1024 * 1024;
}
return std::stoi(s) * multiplier;
}

// parsing guard + function
Index* parse_coarse_quantizer(
const std::string& description,
Expand All @@ -240,8 +253,8 @@ Index* parse_coarse_quantizer(
};
use_2layer = false;

if (match("IVF([0-9]+)")) {
nlist = std::stoi(sm[1].str());
if (match("IVF([0-9]+[kM]?)")) {
nlist = parse_nlist(sm[1].str());
return new IndexFlat(d, mt);
}
if (match("IMI2x([0-9]+)")) {
Expand All @@ -252,18 +265,18 @@ Index* parse_coarse_quantizer(
nlist = (size_t)1 << (2 * nbit);
return new MultiIndexQuantizer(d, 2, nbit);
}
if (match("IVF([0-9]+)_HNSW([0-9]*)")) {
nlist = std::stoi(sm[1].str());
if (match("IVF([0-9]+[kM]?)_HNSW([0-9]*)")) {
nlist = parse_nlist(sm[1].str());
int hnsw_M = sm[2].length() > 0 ? std::stoi(sm[2]) : 32;
return new IndexHNSWFlat(d, hnsw_M, mt);
}
if (match("IVF([0-9]+)_NSG([0-9]+)")) {
nlist = std::stoi(sm[1].str());
if (match("IVF([0-9]+[kM]?)_NSG([0-9]+)")) {
nlist = parse_nlist(sm[1].str());
int R = std::stoi(sm[2]);
return new IndexNSGFlat(d, R, mt);
}
if (match("IVF([0-9]+)\\(Index([0-9])\\)")) {
nlist = std::stoi(sm[1].str());
if (match("IVF([0-9]+[kM]?)\\(Index([0-9])\\)")) {
nlist = parse_nlist(sm[1].str());
int no = std::stoi(sm[2].str());
FAISS_ASSERT(no >= 0 && no < parenthesis_indexes.size());
return parenthesis_indexes[no].release();
Expand Down
12 changes: 12 additions & 0 deletions tests/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,18 @@ def test_ivf(self):
index = faiss.index_factory(123, "IVF456,Flat")
self.assertEqual(index.__class__, faiss.IndexIVFFlat)

def test_ivf_suffix_k(self):
index = faiss.index_factory(123, "IVF3k,Flat")
self.assertEqual(index.nlist, 3072)

def test_ivf_suffix_M(self):
index = faiss.index_factory(123, "IVF1M,Flat")
self.assertEqual(index.nlist, 1024 * 1024)

def test_ivf_suffix_HNSW_M(self):
index = faiss.index_factory(123, "IVF1M_HNSW,Flat")
self.assertEqual(index.nlist, 1024 * 1024)

def test_idmap(self):
index = faiss.index_factory(123, "Flat,IDMap")
self.assertEqual(index.__class__, faiss.IndexIDMap)
Expand Down

0 comments on commit d85fda7

Please sign in to comment.