Skip to content

Commit

Permalink
Add reverse factory string util, add StringIOReader, add centralized …
Browse files Browse the repository at this point in the history
…JK (#3879)

Summary:

1. Adds JK `faiss/telemetry:use_faiss_telemetry_core` to the top level logging util in `wrapper_logging_utils.h`. This is currently set to false. I plan to deprecate the other knobs under https://www.internalfb.com/intern/justknobs/?name=faiss%2Ftelemetry and just use one, as Unicorn can't really have their own JK easily (they subclass a lot of FAISS classes too).
2. Copied StringIOReader from Unicorn to telemetry wrapper in `io.h`. This will be deleted from Unicorn in the follow up diff.
3. Updated Laser tests to reflect correct index_read factory string changes.
4. Adds reverse_index_factory. More tests for it in subsequent diff.

Differential Revision: D62670316
  • Loading branch information
Michael Norris authored and facebook-github-bot committed Sep 21, 2024
1 parent 03f1d2a commit 48049bd
Show file tree
Hide file tree
Showing 3 changed files with 230 additions and 0 deletions.
152 changes: 152 additions & 0 deletions faiss/cppcontrib/factory_tools.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

// -*- c++ -*-

#include <faiss/cppcontrib/factory_tools.h>
#include <map>

namespace faiss {

namespace {

const std::map<faiss::ScalarQuantizer::QuantizerType, std::string> sq_types = {
{faiss::ScalarQuantizer::QT_8bit, "SQ8"},
{faiss::ScalarQuantizer::QT_4bit, "SQ4"},
{faiss::ScalarQuantizer::QT_6bit, "SQ6"},
{faiss::ScalarQuantizer::QT_fp16, "SQfp16"},
{faiss::ScalarQuantizer::QT_bf16, "SQbf16"},
{faiss::ScalarQuantizer::QT_8bit_direct_signed, "SQ8_direct_signed"},
{faiss::ScalarQuantizer::QT_8bit_direct, "SQ8_direct"},
};

int get_hnsw_M(const faiss::IndexHNSW* index) {
if (index->hnsw.cum_nneighbor_per_level.size() >= 1) {
return index->hnsw.cum_nneighbor_per_level[1] / 2;
}
// Avoid runtime error, just return 0.
return 0;
}

} // namespace

// Reference for reverse_index_factory:
// https://github.com/facebookresearch/faiss/blob/838612c9d7f2f619811434ec9209c020f44107cb/contrib/factory_tools.py#L81
std::string reverse_index_factory(const faiss::Index* index) {
std::string prefix;
if (dynamic_cast<const faiss::IndexFlat*>(index)) {
return "Flat";
} else if (
const faiss::IndexIVF* ivf_index =
dynamic_cast<const faiss::IndexIVF*>(index)) {
const faiss::Index* quantizer = ivf_index->quantizer;

if (dynamic_cast<const faiss::IndexFlat*>(quantizer)) {
prefix = "IVF" + std::to_string(ivf_index->nlist);
} else if (
const faiss::MultiIndexQuantizer* miq =
dynamic_cast<const faiss::MultiIndexQuantizer*>(
quantizer)) {
prefix = "IMI" + std::to_string(miq->pq.M) + "x" +
std::to_string(miq->pq.nbits);
} else if (
const faiss::IndexHNSW* hnsw_index =
dynamic_cast<const faiss::IndexHNSW*>(quantizer)) {
prefix = "IVF" + std::to_string(ivf_index->nlist) + "_HNSW" +
std::to_string(get_hnsw_M(hnsw_index));
} else {
prefix = "IVF" + std::to_string(ivf_index->nlist) + "(" +
reverse_index_factory(quantizer) + ")";
}

if (dynamic_cast<const faiss::IndexIVFFlat*>(ivf_index)) {
return prefix + ",Flat";
} else if (
auto sq_index =
dynamic_cast<const faiss::IndexIVFScalarQuantizer*>(
ivf_index)) {
return prefix + "," + sq_types.at(sq_index->sq.qtype);
} else if (
const faiss::IndexIVFPQ* ivfpq_index =
dynamic_cast<const faiss::IndexIVFPQ*>(ivf_index)) {
return prefix + ",PQ" + std::to_string(ivfpq_index->pq.M) + "x" +
std::to_string(ivfpq_index->pq.nbits);
} else if (
const faiss::IndexIVFPQFastScan* ivfpqfs_index =
dynamic_cast<const faiss::IndexIVFPQFastScan*>(
ivf_index)) {
return prefix + ",PQ" + std::to_string(ivfpqfs_index->pq.M) + "x" +
std::to_string(ivfpqfs_index->pq.nbits) + "fs";
}
} else if (
const faiss::IndexPreTransform* pretransform_index =
dynamic_cast<const faiss::IndexPreTransform*>(index)) {
if (pretransform_index->chain.size() != 1) {
// Avoid runtime error, just return empty string for logging.
return "";
}
const faiss::VectorTransform* vt = pretransform_index->chain.at(0);
if (const faiss::OPQMatrix* opq_matrix =
dynamic_cast<const faiss::OPQMatrix*>(vt)) {
prefix = "OPQ" + std::to_string(opq_matrix->M) + "_" +
std::to_string(opq_matrix->d_out);
} else if (
const faiss::ITQTransform* itq_transform =
dynamic_cast<const faiss::ITQTransform*>(vt)) {
prefix = "ITQ" + std::to_string(itq_transform->itq.d_out);
} else if (
const faiss::PCAMatrix* pca_matrix =
dynamic_cast<const faiss::PCAMatrix*>(vt)) {
assert(pca_matrix->eigen_power == 0);
prefix = "PCA" +
std::string(pca_matrix->random_rotation ? "R" : "") +
std::to_string(pca_matrix->d_out);
} else {
// Avoid runtime error, just return empty string for logging.
return "";
}
return prefix + "," + reverse_index_factory(pretransform_index->index);
} else if (
const faiss::IndexHNSW* hnsw_index =
dynamic_cast<const faiss::IndexHNSW*>(index)) {
return "HNSW" + std::to_string(get_hnsw_M(hnsw_index));
} else if (
const faiss::IndexRefine* refine_index =
dynamic_cast<const faiss::IndexRefine*>(index)) {
return reverse_index_factory(refine_index->base_index) + ",Refine(" +
reverse_index_factory(refine_index->refine_index) + ")";
} else if (
const faiss::IndexPQFastScan* pqfs_index =
dynamic_cast<const faiss::IndexPQFastScan*>(index)) {
return std::string("PQ") + std::to_string(pqfs_index->pq.M) + "x" +
std::to_string(pqfs_index->pq.nbits) + "fs";
} else if (
const faiss::IndexPQ* pq_index =
dynamic_cast<const faiss::IndexPQ*>(index)) {
return std::string("PQ") + std::to_string(pq_index->pq.M) + "x" +
std::to_string(pq_index->pq.nbits);
} else if (
const faiss::IndexLSH* lsh_index =
dynamic_cast<const faiss::IndexLSH*>(index)) {
std::string result = "LSH";
if (lsh_index->rotate_data) {
result += "r";
}
if (lsh_index->train_thresholds) {
result += "t";
}
return result;
} else if (
const faiss::IndexScalarQuantizer* sq_index =
dynamic_cast<const faiss::IndexScalarQuantizer*>(index)) {
return std::string("SQ") + sq_types.at(sq_index->sq.qtype);
}
// Avoid runtime error, just return empty string for logging.
return "";
}

} // namespace faiss
24 changes: 24 additions & 0 deletions faiss/cppcontrib/factory_tools.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

// -*- c++ -*-

#pragma once

#include <faiss/IndexHNSW.h>
#include <faiss/IndexIVFFlat.h>
#include <faiss/IndexIVFPQFastScan.h>
#include <faiss/IndexLSH.h>
#include <faiss/IndexPQFastScan.h>
#include <faiss/IndexPreTransform.h>
#include <faiss/IndexRefine.h>

namespace faiss {

std::string reverse_index_factory(const faiss::Index* index);

} // namespace faiss
54 changes: 54 additions & 0 deletions tests/test_factory_tools.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

#include <faiss/cppcontrib/factory_tools.h>
#include <faiss/index_factory.h>
#include <gtest/gtest.h>

using namespace faiss;

TEST(TestFactoryTools, TestReverseIndexFactory) {
auto factory_string = "Flat";
auto index = faiss::index_factory(64, factory_string);
EXPECT_EQ(factory_string, reverse_index_factory(index));
delete index;

factory_string = "IMI2x5,PQ8x8";
index = faiss::index_factory(32, factory_string);
EXPECT_EQ(factory_string, reverse_index_factory(index));
delete index;

factory_string = "IVF32_HNSW32,SQ8";
index = faiss::index_factory(64, factory_string);
EXPECT_EQ(factory_string, reverse_index_factory(index));
delete index;

factory_string = "IVF8,Flat";
index = faiss::index_factory(64, factory_string);
EXPECT_EQ(factory_string, reverse_index_factory(index));
delete index;

factory_string = "IVF8,SQ4";
index = faiss::index_factory(64, factory_string);
EXPECT_EQ(factory_string, reverse_index_factory(index));
delete index;

factory_string = "IVF8,PQ4x8";
index = faiss::index_factory(64, factory_string);
EXPECT_EQ(factory_string, reverse_index_factory(index));
delete index;

factory_string = "LSHrt";
index = faiss::index_factory(64, factory_string);
EXPECT_EQ(factory_string, reverse_index_factory(index));
delete index;

factory_string = "PQ4x8";
index = faiss::index_factory(64, factory_string);
EXPECT_EQ(factory_string, reverse_index_factory(index));
delete index;

factory_string = "HNSW32";
index = faiss::index_factory(64, factory_string);
EXPECT_EQ(factory_string, reverse_index_factory(index));
delete index;
}

0 comments on commit 48049bd

Please sign in to comment.