Skip to content

Commit

Permalink
Support independent IVF coarse quantizer
Browse files Browse the repository at this point in the history
Summary: In the IndexIVFIndepenentQuantizer, the coarse quantizer is applied on the input vectors, but the encoding is performed on a vector-transformed version of the database elements.

Reviewed By: alexanderguzhva

Differential Revision: D45950970

fbshipit-source-id: 30f6cf46d44174b1d99a12384b7d5e2d475c1f88
  • Loading branch information
mdouze authored and facebook-github-bot committed May 26, 2023
1 parent a3296f4 commit 6800ebe
Show file tree
Hide file tree
Showing 14 changed files with 448 additions and 55 deletions.
14 changes: 14 additions & 0 deletions contrib/inspect_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,20 @@ def get_LinearTransform_matrix(pca):
return A, b


def make_LinearTransform_matrix(A, b=None):
""" make a linear transform from a matrix and a bias term (optional)"""
d_out, d_in = A.shape
if b is not None:
assert b.shape == (d_out, )
lt = faiss.LinearTransform(d_in, d_out, b is not None)
faiss.copy_array_to_vector(A.ravel(), lt.A)
if b is not None:
faiss.copy_array_to_vector(b, lt.b)
lt.is_trained = True
lt.set_is_orthonormal()
return lt


def get_additive_quantizer_codebooks(aq):
""" return to codebooks of an additive quantizer """
codebooks = faiss.vector_to_array(aq.codebooks).reshape(-1, aq.d)
Expand Down
2 changes: 2 additions & 0 deletions faiss/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ set(FAISS_SRC
IndexPQ.cpp
IndexFastScan.cpp
IndexAdditiveQuantizerFastScan.cpp
IndexIVFIndependentQuantizer.cpp
IndexPQFastScan.cpp
IndexPreTransform.cpp
IndexRefine.cpp
Expand Down Expand Up @@ -113,6 +114,7 @@ set(FAISS_HEADERS
IndexIDMap.h
IndexIVF.h
IndexIVFAdditiveQuantizer.h
IndexIVFIndependentQuantizer.h
IndexIVFFlat.h
IndexIVFPQ.h
IndexIVFFastScan.h
Expand Down
5 changes: 5 additions & 0 deletions faiss/IVFlib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include <faiss/IndexAdditiveQuantizer.h>
#include <faiss/IndexIVFAdditiveQuantizer.h>
#include <faiss/IndexIVFIndependentQuantizer.h>
#include <faiss/IndexPreTransform.h>
#include <faiss/MetaIndexes.h>
#include <faiss/impl/FaissAssert.h>
Expand Down Expand Up @@ -67,6 +68,10 @@ const IndexIVF* try_extract_index_ivf(const Index* index) {
if (auto* idmap = dynamic_cast<const IndexIDMap2*>(index)) {
index = idmap->index;
}
if (auto* indep =
dynamic_cast<const IndexIVFIndependentQuantizer*>(index)) {
index = indep->index_ivf;
}

auto* ivf = dynamic_cast<const IndexIVF*>(index);

Expand Down
172 changes: 172 additions & 0 deletions faiss/IndexIVFIndependentQuantizer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <faiss/IndexIVFIndependentQuantizer.h>
#include <faiss/IndexIVFPQ.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/utils.h>

namespace faiss {

IndexIVFIndependentQuantizer::IndexIVFIndependentQuantizer(
Index* quantizer,
IndexIVF* index_ivf,
VectorTransform* vt)
: Index(quantizer->d, index_ivf->metric_type),
quantizer(quantizer),
vt(vt),
index_ivf(index_ivf) {
if (vt) {
FAISS_THROW_IF_NOT_MSG(
vt->d_in == d && vt->d_out == index_ivf->d,
"invalid vector dimensions");
} else {
FAISS_THROW_IF_NOT_MSG(index_ivf->d == d, "invalid vector dimensions");
}

if (quantizer->is_trained && quantizer->ntotal != 0) {
FAISS_THROW_IF_NOT(quantizer->ntotal == index_ivf->nlist);
}
if (index_ivf->is_trained && vt) {
FAISS_THROW_IF_NOT(vt->is_trained);
}
ntotal = index_ivf->ntotal;
is_trained =
(quantizer->is_trained && quantizer->ntotal == index_ivf->nlist &&
(!vt || vt->is_trained) && index_ivf->is_trained);

// disable precomputed tables because they use the distances that are
// provided by the coarse quantizer (that are out of sync with the IVFPQ)
if (auto index_ivfpq = dynamic_cast<IndexIVFPQ*>(index_ivf)) {
index_ivfpq->use_precomputed_table = -1;
}
}

IndexIVFIndependentQuantizer::~IndexIVFIndependentQuantizer() {
if (own_fields) {
delete quantizer;
delete index_ivf;
delete vt;
}
}

namespace {

struct VTransformedVectors : TransformedVectors {
VTransformedVectors(const VectorTransform* vt, idx_t n, const float* x)
: TransformedVectors(x, vt ? vt->apply(n, x) : x) {}
};

struct SubsampledVectors : TransformedVectors {
SubsampledVectors(int d, idx_t* n, idx_t max_n, const float* x)
: TransformedVectors(
x,
fvecs_maybe_subsample(d, (size_t*)n, max_n, x, true)) {}
};

} // anonymous namespace

void IndexIVFIndependentQuantizer::add(idx_t n, const float* x) {
std::vector<float> D(n);
std::vector<idx_t> I(n);
quantizer->search(n, x, 1, D.data(), I.data());

VTransformedVectors tv(vt, n, x);

index_ivf->add_core(n, tv.x, nullptr, I.data());
}

void IndexIVFIndependentQuantizer::search(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels,
const SearchParameters* params) const {
FAISS_THROW_IF_NOT_MSG(!params, "search parameters not supported");
int nprobe = index_ivf->nprobe;
std::vector<float> D(n * nprobe);
std::vector<idx_t> I(n * nprobe);
quantizer->search(n, x, nprobe, D.data(), I.data());

VTransformedVectors tv(vt, n, x);

index_ivf->search_preassigned(
n, tv.x, k, I.data(), D.data(), distances, labels, false);
}

void IndexIVFIndependentQuantizer::reset() {
index_ivf->reset();
ntotal = 0;
}

void IndexIVFIndependentQuantizer::train(idx_t n, const float* x) {
// quantizer training
size_t nlist = index_ivf->nlist;
Level1Quantizer l1(quantizer, nlist);
l1.train_q1(n, x, verbose, metric_type);

// train the VectorTransform
if (vt && !vt->is_trained) {
if (verbose) {
printf("IndexIVFIndependentQuantizer: train the VectorTransform\n");
}
vt->train(n, x);
}

// get the centroids from the quantizer, transform them and
// add them to the index_ivf's quantizer
if (verbose) {
printf("IndexIVFIndependentQuantizer: extract the main quantizer centroids\n");
}
std::vector<float> centroids(nlist * d);
quantizer->reconstruct_n(0, nlist, centroids.data());
VTransformedVectors tcent(vt, nlist, centroids.data());

if (verbose) {
printf("IndexIVFIndependentQuantizer: add centroids to the secondary quantizer\n");
}
if (!index_ivf->quantizer->is_trained) {
index_ivf->quantizer->train(nlist, tcent.x);
}
index_ivf->quantizer->add(nlist, tcent.x);

// train the payload

// optional subsampling
idx_t max_nt = index_ivf->train_encoder_num_vectors();
if (max_nt <= 0) {
max_nt = (size_t)1 << 35;
}
SubsampledVectors sv(index_ivf->d, &n, max_nt, x);

// transform subsampled vectors
VTransformedVectors tv(vt, n, sv.x);

if (verbose) {
printf("IndexIVFIndependentQuantizer: train encoder\n");
}

if (index_ivf->by_residual) {
// assign with quantizer
std::vector<idx_t> assign(n);
quantizer->assign(n, sv.x, assign.data());

// compute residual with IVF quantizer
std::vector<float> residuals(n * index_ivf->d);
index_ivf->quantizer->compute_residual_n(
n, tv.x, residuals.data(), assign.data());

index_ivf->train_encoder(n, residuals.data(), assign.data());
} else {
index_ivf->train_encoder(n, tv.x, nullptr);
}
index_ivf->is_trained = true;
is_trained = true;
}

} // namespace faiss
56 changes: 56 additions & 0 deletions faiss/IndexIVFIndependentQuantizer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <faiss/IndexIVF.h>
#include <faiss/VectorTransform.h>

namespace faiss {

/** An IVF index with a quantizer that has a different input dimension from the
* payload size. The vectors to encode are obtained from the input vectors by a
* VectorTransform.
*/
struct IndexIVFIndependentQuantizer : Index {
/// quantizer is fed directly with the input vectors
Index* quantizer = nullptr;

/// transform before the IVF vectors are applied
VectorTransform* vt = nullptr;

/// the IVF index, controls nlist and nprobe
IndexIVF* index_ivf = nullptr;

/// whether *this owns the 3 fields
bool own_fields = false;

IndexIVFIndependentQuantizer(
Index* quantizer,
IndexIVF* index_ivf,
VectorTransform* vt = nullptr);

IndexIVFIndependentQuantizer() {}

void train(idx_t n, const float* x) override;

void add(idx_t n, const float* x) override;

void search(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels,
const SearchParameters* params = nullptr) const override;

void reset() override;

~IndexIVFIndependentQuantizer() override;
};

} // namespace faiss
29 changes: 10 additions & 19 deletions faiss/IndexPreTransform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,8 @@ void IndexPreTransform::reverse_chain(idx_t n, const float* xt, float* x)

void IndexPreTransform::add(idx_t n, const float* x) {
FAISS_THROW_IF_NOT(is_trained);
const float* xt = apply_chain(n, x);
ScopeDeleter<float> del(xt == x ? nullptr : xt);
index->add(n, xt);
TransformedVectors tv(x, apply_chain(n, x));
index->add(n, tv.x);
ntotal = index->ntotal;
}

Expand All @@ -152,9 +151,8 @@ void IndexPreTransform::add_with_ids(
const float* x,
const idx_t* xids) {
FAISS_THROW_IF_NOT(is_trained);
const float* xt = apply_chain(n, x);
ScopeDeleter<float> del(xt == x ? nullptr : xt);
index->add_with_ids(n, xt, xids);
TransformedVectors tv(x, apply_chain(n, x));
index->add_with_ids(n, tv.x, xids);
ntotal = index->ntotal;
}

Expand Down Expand Up @@ -190,10 +188,9 @@ void IndexPreTransform::range_search(
RangeSearchResult* result,
const SearchParameters* params) const {
FAISS_THROW_IF_NOT(is_trained);
const float* xt = apply_chain(n, x);
ScopeDeleter<float> del(xt == x ? nullptr : xt);
TransformedVectors tv(x, apply_chain(n, x));
index->range_search(
n, xt, radius, result, extract_index_search_params(params));
n, tv.x, radius, result, extract_index_search_params(params));
}

void IndexPreTransform::reset() {
Expand Down Expand Up @@ -238,14 +235,13 @@ void IndexPreTransform::search_and_reconstruct(
FAISS_THROW_IF_NOT(k > 0);
FAISS_THROW_IF_NOT(is_trained);

const float* xt = apply_chain(n, x);
ScopeDeleter<float> del((xt == x) ? nullptr : xt);
TransformedVectors trans(x, apply_chain(n, x));

float* recons_temp = chain.empty() ? recons : new float[n * k * index->d];
ScopeDeleter<float> del2((recons_temp == recons) ? nullptr : recons_temp);
index->search_and_reconstruct(
n,
xt,
trans.x,
k,
distances,
labels,
Expand All @@ -262,13 +258,8 @@ size_t IndexPreTransform::sa_code_size() const {

void IndexPreTransform::sa_encode(idx_t n, const float* x, uint8_t* bytes)
const {
if (chain.empty()) {
index->sa_encode(n, x, bytes);
} else {
const float* xt = apply_chain(n, x);
ScopeDeleter<float> del(xt == x ? nullptr : xt);
index->sa_encode(n, xt, bytes);
}
TransformedVectors tv(x, apply_chain(n, x));
index->sa_encode(n, tv.x, bytes);
}

void IndexPreTransform::sa_decode(idx_t n, const uint8_t* bytes, float* x)
Expand Down
Loading

0 comments on commit 6800ebe

Please sign in to comment.