reduce cross table size (facebookresearch#3012)
Summary:
Pull Request resolved: facebookresearch#3012

The cross-tables for codebook construction contained the dot products between all pairs of codebook entries, including pairs from the same codebook, which is not necessary (and caused OOMs in some cases). This diff computes and stores only the off-diagonal blocks, i.e. the dot products between entries of different codebooks.
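
For a sense of the memory saving, a back-of-the-envelope sketch (the M = 64 codebooks of 8 bits assumed below are illustrative, not taken from this diff):

// Illustration only: compares the old full cross-table with the
// off-diagonal blocks for an assumed configuration (M = 64, nbits = 8).
#include <cstddef>
#include <cstdio>

int main() {
    const size_t M = 64, K = 1 << 8;    // 64 codebooks of 256 entries each
    const size_t total = M * K;         // total_codebook_size = 16384
    size_t full = total * total;        // old: all pairs of entries
    size_t off_diag = 0;                // new: entry x previous codebooks only
    for (size_t m = 0; m < M; m++) {
        off_diag += K * (m * K);        // K * codebook_offsets[m]
    }
    printf("full: %zu floats (~1 GiB), off-diagonal: %zu floats (~0.5 GiB)\n",
           full, off_diag);
    return 0;
}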

Reviewed By: pemazare

Differential Revision: D48448615

fbshipit-source-id: 494b54e2900754a3ff5d3c8073cb9a768e578c58
mdouze authored and facebook-github-bot committed Sep 1, 2023
1 parent 039409d commit 9dc75d0
Showing 5 changed files with 153 additions and 133 deletions.
31 changes: 19 additions & 12 deletions faiss/impl/ResidualQuantizer.cpp
@@ -493,29 +493,36 @@ void ResidualQuantizer::refine_beam(
  *******************************************************************/
 
 void ResidualQuantizer::compute_codebook_tables() {
-    codebook_cross_products.resize(total_codebook_size * total_codebook_size);
     cent_norms.resize(total_codebook_size);
-    // stricly speaking we could use ssyrk
-    {
-        FINTEGER ni = total_codebook_size;
+    fvec_norms_L2sqr(
+            cent_norms.data(), codebooks.data(), d, total_codebook_size);
+    size_t cross_table_size = 0;
+    for (int m = 0; m < M; m++) {
+        size_t K = (size_t)1 << nbits[m];
+        cross_table_size += K * codebook_offsets[m];
+    }
+    codebook_cross_products.resize(cross_table_size);
+    size_t ofs = 0;
+    for (int m = 1; m < M; m++) {
+        FINTEGER ki = (size_t)1 << nbits[m];
+        FINTEGER kk = codebook_offsets[m];
         FINTEGER di = d;
         float zero = 0, one = 1;
+        assert(ofs + ki * kk <= cross_table_size);
         sgemm_("Transposed",
                "Not transposed",
-               &ni,
-               &ni,
+               &ki,
+               &kk,
                &di,
                &one,
-               codebooks.data(),
+               codebooks.data() + d * kk,
               &di,
                codebooks.data(),
                &di,
                &zero,
-               codebook_cross_products.data(),
-               &ni);
-    }
-    for (size_t i = 0; i < total_codebook_size; i++) {
-        cent_norms[i] = codebook_cross_products[i + i * total_codebook_size];
+               codebook_cross_products.data() + ofs,
+               &ki);
+        ofs += ki * kk;
    }
 }
 
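As a reading aid (not part of the diff): for each codebook m >= 1 the sgemm above writes a ki x kk column-major block whose column j holds the dot products of previous entry j (j < codebook_offsets[m]) with every entry of codebook m, and the blocks for m = 1..M-1 are concatenated in order. A sketch of how such a packed table can be addressed; lookup_cross is a hypothetical helper, not a faiss function:

#include <cstddef>
#include <cstdint>
#include <vector>

// Illustration of the packed layout: block m has ki = 2^nbits[m] rows and
// kk = codebook_offsets[m] columns, stored column-major with leading
// dimension ki. This per-step block is what the encoder later receives
// with ldc = K.
float lookup_cross(
        const std::vector<float>& cross,               // packed cross products
        const std::vector<uint64_t>& codebook_offsets,
        const std::vector<size_t>& nbits,
        size_t m,   // current step, m >= 1
        size_t k,   // entry within codebook m
        size_t j) { // global index of an entry from a previous codebook
    size_t ofs = 0;
    for (size_t m1 = 1; m1 < m; m1++) {
        ofs += (size_t(1) << nbits[m1]) * codebook_offsets[m1];
    }
    size_t ki = size_t(1) << nbits[m];
    return cross[ofs + j * ki + k]; // <entry k of codebook m, previous entry j>
}
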
6 changes: 3 additions & 3 deletions faiss/impl/ResidualQuantizer.h
@@ -148,10 +148,10 @@ struct ResidualQuantizer : AdditiveQuantizer {
      */
     void compute_codebook_tables();
 
-    /// dot products of all codebook vectors with each other
-    /// size total_codebook_size * total_codebook_size
+    /// dot products of all codebook entries with the previous codebooks
+    /// size sum(codebook_offsets[m] * 2^nbits[m], m=0..M-1)
     std::vector<float> codebook_cross_products;
-    /// norms of all vectors
+    /// norms of all codebook entries (size total_codebook_size)
     std::vector<float> cent_norms;
 };
 
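As a quick check of the size formula above (toy numbers, not from the diff): with M = 3 codebooks of 2 bits each, codebook_offsets = {0, 4, 8}, so the table holds 0*4 + 4*4 + 8*4 = 48 floats, versus total_codebook_size^2 = 12*12 = 144 for the old full table.
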
201 changes: 103 additions & 98 deletions faiss/impl/residual_quantizer_encode_steps.cpp
@@ -384,11 +384,11 @@ void beam_search_encode_step_tab(
         size_t n,
         size_t beam_size, // input sizes
         const float* codebook_cross_norms, // size K * ldc
-        size_t ldc, // >= K
-        const uint64_t* codebook_offsets, // m
-        const float* query_cp, // size n * ldqc
-        size_t ldqc, // >= K
-        const float* cent_norms_i, // size K
+        size_t ldc,
+        const uint64_t* codebook_offsets, // m
+        const float* query_cp, // size n * ldqc
+        size_t ldqc, // >= K
+        const float* cent_norms_i, // size K
         size_t m,
         const int32_t* codes, // n * beam_size * m
         const float* distances, // n * beam_size
@@ -412,35 +412,38 @@ void beam_search_encode_step_tab(
             cd_common[k] = cent_norms_i[k] - 2 * query_cp_i[k];
         }
 
-        /*
+        bool use_baseline_implementation = false;
+
         // This is the baseline implementation. Its primary flaw
         // that it writes way too many info to the temporary buffer
         // called dp.
         //
         // This baseline code is kept intentionally because it is easy to
         // understand what an optimized version optimizes exactly.
         //
-        for (size_t b = 0; b < beam_size; b++) {
-            std::vector<float> dp(K);
-            for (size_t m1 = 0; m1 < m; m1++) {
-                size_t c = codes_i[b * m + m1];
-                const float* cb =
-                        &codebook_cross_norms[(codebook_offsets[m1] + c) * ldc];
-                fvec_add(K, cb, dp.data(), dp.data());
-            }
-
-            for (size_t k = 0; k < K; k++) {
-                cent_distances[b * K + k] =
-                        distances_i[b] + cd_common[k] + 2 * dp[k];
-            }
-        }
-        */
-
-        // An optimized implementation that avoids using a temporary buffer
-        // and does the accumulation in registers.
-
-        // Compute a sum of NK AQ codes.
+        if (use_baseline_implementation) {
+            for (size_t b = 0; b < beam_size; b++) {
+                std::vector<float> dp(K);
+
+                for (size_t m1 = 0; m1 < m; m1++) {
+                    size_t c = codes_i[b * m + m1];
+                    const float* cb =
+                            &codebook_cross_norms
+                                    [(codebook_offsets[m1] + c) * ldc];
+                    fvec_add(K, cb, dp.data(), dp.data());
+                }
+
+                for (size_t k = 0; k < K; k++) {
+                    cent_distances[b * K + k] =
+                            distances_i[b] + cd_common[k] + 2 * dp[k];
+                }
+            }
+        } else {
+            // An optimized implementation that avoids using a temporary buffer
+            // and does the accumulation in registers.
+
+            // Compute a sum of NK AQ codes.
 #define ACCUM_AND_FINALIZE_TAB(NK) \
     case NK: \
         for (size_t b = 0; b < beam_size; b++) { \
@@ -457,51 +460,52 @@ void beam_search_encode_step_tab(
         } \
         break;
 
-        // this version contains many switch-case scenarios, but
-        // they won't affect branch predictor.
-        switch (m) {
-            case 0:
-                // trivial case
-                for (size_t b = 0; b < beam_size; b++) {
-                    for (size_t k = 0; k < K; k++) {
-                        cent_distances[b * K + k] =
-                                distances_i[b] + cd_common[k];
-                    }
-                }
-                break;
-
-                ACCUM_AND_FINALIZE_TAB(1)
-                ACCUM_AND_FINALIZE_TAB(2)
-                ACCUM_AND_FINALIZE_TAB(3)
-                ACCUM_AND_FINALIZE_TAB(4)
-                ACCUM_AND_FINALIZE_TAB(5)
-                ACCUM_AND_FINALIZE_TAB(6)
-                ACCUM_AND_FINALIZE_TAB(7)
-
-            default: {
-                // m >= 8 case.
-
-                // A temporary buffer has to be used due to the lack of
-                // registers. But we'll try to accumulate up to 8 AQ codes in
-                // registers and issue a single write operation to the buffer,
-                // while the baseline does no accumulation. So, the number of
-                // write operations to the temporary buffer is reduced 8x.
-
-                // allocate a temporary buffer
-                std::vector<float> dp(K);
-
-                for (size_t b = 0; b < beam_size; b++) {
-                    // Initialize it. Compute a sum of first 8 AQ codes
-                    // because m >= 8 .
-                    accum_and_store_tab<8, 4>(
-                            m,
-                            codebook_cross_norms,
-                            codebook_offsets,
-                            codes_i,
-                            b,
-                            ldc,
-                            K,
-                            dp.data());
+            // this version contains many switch-case scenarios, but
+            // they won't affect branch predictor.
+            switch (m) {
+                case 0:
+                    // trivial case
+                    for (size_t b = 0; b < beam_size; b++) {
+                        for (size_t k = 0; k < K; k++) {
+                            cent_distances[b * K + k] =
+                                    distances_i[b] + cd_common[k];
+                        }
+                    }
+                    break;
+
+                    ACCUM_AND_FINALIZE_TAB(1)
+                    ACCUM_AND_FINALIZE_TAB(2)
+                    ACCUM_AND_FINALIZE_TAB(3)
+                    ACCUM_AND_FINALIZE_TAB(4)
+                    ACCUM_AND_FINALIZE_TAB(5)
+                    ACCUM_AND_FINALIZE_TAB(6)
+                    ACCUM_AND_FINALIZE_TAB(7)
+
+                default: {
+                    // m >= 8 case.
+
+                    // A temporary buffer has to be used due to the lack of
+                    // registers. But we'll try to accumulate up to 8 AQ codes
+                    // in registers and issue a single write operation to the
+                    // buffer, while the baseline does no accumulation. So, the
+                    // number of write operations to the temporary buffer is
+                    // reduced 8x.
+
+                    // allocate a temporary buffer
+                    std::vector<float> dp(K);
+
+                    for (size_t b = 0; b < beam_size; b++) {
+                        // Initialize it. Compute a sum of first 8 AQ codes
+                        // because m >= 8 .
+                        accum_and_store_tab<8, 4>(
+                                m,
+                                codebook_cross_norms,
+                                codebook_offsets,
+                                codes_i,
+                                b,
+                                ldc,
+                                K,
+                                dp.data());
 
 #define ACCUM_AND_ADD_TAB(NK) \
     case NK: \
@@ -516,37 +520,37 @@ void beam_search_encode_step_tab(
                 dp.data()); \
         break;
 
-                    // accumulate up to 8 additional AQ codes into
-                    // a temporary buffer
-                    for (size_t im = 8; im < ((m + 7) / 8) * 8; im += 8) {
-                        size_t m_left = m - im;
-                        if (m_left > 8) {
-                            m_left = 8;
-                        }
-
-                        switch (m_left) {
-                            ACCUM_AND_ADD_TAB(1)
-                            ACCUM_AND_ADD_TAB(2)
-                            ACCUM_AND_ADD_TAB(3)
-                            ACCUM_AND_ADD_TAB(4)
-                            ACCUM_AND_ADD_TAB(5)
-                            ACCUM_AND_ADD_TAB(6)
-                            ACCUM_AND_ADD_TAB(7)
-                            ACCUM_AND_ADD_TAB(8)
-                        }
-                    }
-
-                    // done. finalize the result
-                    for (size_t k = 0; k < K; k++) {
-                        cent_distances[b * K + k] =
-                                distances_i[b] + cd_common[k] + 2 * dp[k];
-                    }
-                }
-            }
-        }
-
-        // the optimized implementation ends here
-
+                        // accumulate up to 8 additional AQ codes into
+                        // a temporary buffer
+                        for (size_t im = 8; im < ((m + 7) / 8) * 8; im += 8) {
+                            size_t m_left = m - im;
+                            if (m_left > 8) {
+                                m_left = 8;
+                            }
+
+                            switch (m_left) {
+                                ACCUM_AND_ADD_TAB(1)
+                                ACCUM_AND_ADD_TAB(2)
+                                ACCUM_AND_ADD_TAB(3)
+                                ACCUM_AND_ADD_TAB(4)
+                                ACCUM_AND_ADD_TAB(5)
+                                ACCUM_AND_ADD_TAB(6)
+                                ACCUM_AND_ADD_TAB(7)
+                                ACCUM_AND_ADD_TAB(8)
+                            }
+                        }
+
+                        // done. finalize the result
+                        for (size_t k = 0; k < K; k++) {
+                            cent_distances[b * K + k] =
+                                    distances_i[b] + cd_common[k] + 2 * dp[k];
+                        }
+                    }
+                }
+            }
+
+            // the optimized implementation ends here
+        }
         using C = CMax<float, int>;
         int32_t* new_codes_i = new_codes + i * (m + 1) * new_beam_size;
         float* new_distances_i = new_distances + i * new_beam_size;
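For readers reconstructing what the precomputed tables mean, an explanatory sketch (not part of the diff): with x the query, c_prev the sum of the entries already chosen for beam entry b, and c_k entry k of the current codebook, the loops above evaluate the expanded squared distance using only precomputed dot products:

// Explanatory sketch only, in the variable names used above:
//
//   ||x - c_prev - c_k||^2
//       = ||x - c_prev||^2        -> distances_i[b]
//       + ||c_k||^2               -> cent_norms_i[k]
//       - 2 * <x, c_k>            -> -2 * query_cp_i[k]
//       + 2 * <c_prev, c_k>       -> +2 * dp[k]
//
// i.e. cent_distances[b * K + k] = distances_i[b] + cd_common[k] + 2 * dp[k]
// with cd_common[k] = cent_norms_i[k] - 2 * query_cp_i[k]. The cross table
// only ever supplies <previous entry, current entry> terms, which is why the
// same-codebook (diagonal) blocks of the old full table were never read.
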
@@ -784,6 +788,7 @@ void refine_beam_LUT_mp(
     // main loop
     size_t codes_size = 0;
     size_t distances_size = 0;
+    size_t cross_ofs = 0;
     for (int m = 0; m < rq.M; m++) {
         int K = 1 << rq.nbits[m];
 
@@ -792,13 +797,15 @@ void refine_beam_LUT_mp(
 
         codes_size = n * new_beam_size * (m + 1);
         distances_size = n * new_beam_size;
-
+        FAISS_THROW_IF_NOT(
+                cross_ofs + rq.codebook_offsets[m] * K <=
+                rq.codebook_cross_products.size());
         beam_search_encode_step_tab(
                 K,
                 n,
                 beam_size,
-                rq.codebook_cross_products.data() + rq.codebook_offsets[m],
-                rq.total_codebook_size,
+                rq.codebook_cross_products.data() + cross_ofs,
+                K,
                 rq.codebook_offsets.data(),
                 query_cp + rq.codebook_offsets[m],
                 rq.total_codebook_size,
@@ -810,7 +817,7 @@ void refine_beam_LUT_mp(
                 new_codes_ptr,
                 new_distances_ptr,
                 rq.approx_topk_mode);
-
+        cross_ofs += rq.codebook_offsets[m] * K;
         std::swap(codes_ptr, new_codes_ptr);
         std::swap(distances_ptr, new_distances_ptr);
 
@@ -830,7 +837,6 @@ void refine_beam_LUT_mp(
                     beam_size);
         }
     }
-
     if (out_codes) {
         memcpy(out_codes, codes_ptr, codes_size * sizeof(*codes_ptr));
     }
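To connect the caller to the new layout, a standalone restatement of the bookkeeping in refine_beam_LUT_mp above (illustration only; not code from the diff):

#include <cassert>
#include <faiss/impl/ResidualQuantizer.h>

// Block m of the packed table starts at cross_ofs and is the
// K x codebook_offsets[m] column-major block handed to
// beam_search_encode_step_tab with leading dimension K.
void check_cross_table_layout(const faiss::ResidualQuantizer& rq) {
    size_t cross_ofs = 0;
    for (size_t m = 0; m < rq.M; m++) {
        size_t K = size_t(1) << rq.nbits[m];
        const float* block_m = rq.codebook_cross_products.data() + cross_ofs;
        (void)block_m; // passed for encoding step m
        cross_ofs += rq.codebook_offsets[m] * K;
    }
    // after the last step the whole packed table has been walked exactly once
    assert(cross_ofs == rq.codebook_cross_products.size());
}
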
@@ -903,8 +909,7 @@ void compute_codes_add_centroids_mp_lut1(
     pool.distances.resize(rq.max_beam_size * n);
 
     FAISS_THROW_IF_NOT_MSG(
-            rq.codebook_cross_products.size() ==
-                    rq.total_codebook_size * rq.total_codebook_size,
+            rq.M == 1 || rq.codebook_cross_products.size() > 0,
             "call compute_codebook_tables first");
 
     pool.query_norms.resize(n);
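
Since the check above is the only guard on the caller side, a minimal usage sketch (illustrative; the use_beam_LUT flag and the explicit compute_codebook_tables() call follow the error message above and are assumptions, not something this diff prescribes):

#include <faiss/impl/ResidualQuantizer.h>
#include <cstdint>
#include <vector>

// Toy configuration; the d, M, nbits values are assumptions for illustration.
void encode_with_lut(const float* xt, size_t nt, const float* xb, size_t nb) {
    faiss::ResidualQuantizer rq(64, 8, 8); // d = 64, M = 8, nbits = 8
    rq.use_beam_LUT = 1;                   // select the LUT-based beam search
    rq.train(nt, xt);
    rq.compute_codebook_tables();          // cent_norms + packed cross products
    std::vector<uint8_t> codes(nb * rq.code_size);
    rq.compute_codes(xb, codes.data(), nb);
}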
4 changes: 4 additions & 0 deletions faiss/impl/residual_quantizer_encode_steps.h
@@ -86,6 +86,10 @@ void beam_search_encode_step_tab(
 
 /********************************************************************
  * Multiple encoding steps
+ *
+ * The following functions take buffer objects that they use as temp
+ * memory (allocated within the functions). The buffers are intended
+ * to be re-used over batches of points to encode.
  ********************************************************************/
 
 struct ResidualQuantizer;
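
The comment above describes the intended calling pattern; a generic sketch of it (BatchPool and encode_batch below are hypothetical stand-ins, not the structs and functions declared in this header):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical pool standing in for the buffer objects the comment refers to.
struct BatchPool {
    std::vector<int32_t> codes;
    std::vector<float> distances;
};

// Hypothetical per-batch worker: grows (never frees) the pooled buffers.
void encode_batch(size_t n, size_t beam_size, BatchPool& pool) {
    pool.codes.resize(std::max(pool.codes.size(), n * beam_size));
    pool.distances.resize(std::max(pool.distances.size(), n * beam_size));
    // ... per-batch encoding work would happen here ...
}

void encode_all(size_t n, size_t batch_size, size_t beam_size) {
    BatchPool pool; // allocated once, reused by every batch below
    for (size_t i0 = 0; i0 < n; i0 += batch_size) {
        encode_batch(std::min(batch_size, n - i0), beam_size, pool);
    }
}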