-
Notifications
You must be signed in to change notification settings - Fork 82
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
faiss_hnsw support INT8 #991
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -292,7 +292,7 @@ class BaseFaissRegularIndexNode : public BaseFaissIndexNode { | |
}; | ||
|
||
// | ||
enum class DataFormatEnum { fp32, fp16, bf16 }; | ||
enum class DataFormatEnum { fp32, fp16, bf16, int8 }; | ||
|
||
template <typename T> | ||
struct DataType2EnumHelper {}; | ||
|
@@ -309,14 +309,16 @@ template <> | |
struct DataType2EnumHelper<knowhere::bf16> { | ||
static constexpr DataFormatEnum value = DataFormatEnum::bf16; | ||
}; | ||
template <> | ||
struct DataType2EnumHelper<knowhere::int8> { | ||
static constexpr DataFormatEnum value = DataFormatEnum::int8; | ||
}; | ||
|
||
template <typename T> | ||
static constexpr DataFormatEnum datatype_v = DataType2EnumHelper<T>::value; | ||
|
||
// | ||
namespace { | ||
|
||
// | ||
bool | ||
convert_rows_to_fp32(const void* const __restrict src_in, float* const __restrict dst, | ||
const DataFormatEnum src_data_format, const size_t start_row, const size_t nrows, | ||
|
@@ -326,21 +328,24 @@ convert_rows_to_fp32(const void* const __restrict src_in, float* const __restric | |
for (size_t i = 0; i < nrows * dim; i++) { | ||
dst[i] = (float)(src[i + start_row * dim]); | ||
} | ||
|
||
return true; | ||
} else if (src_data_format == DataFormatEnum::bf16) { | ||
const knowhere::bf16* const src = reinterpret_cast<const knowhere::bf16*>(src_in); | ||
for (size_t i = 0; i < nrows * dim; i++) { | ||
dst[i] = (float)(src[i + start_row * dim]); | ||
} | ||
|
||
return true; | ||
} else if (src_data_format == DataFormatEnum::fp32) { | ||
const knowhere::fp32* const src = reinterpret_cast<const knowhere::fp32*>(src_in); | ||
for (size_t i = 0; i < nrows * dim; i++) { | ||
dst[i] = src[i + start_row * dim]; | ||
} | ||
|
||
return true; | ||
} else if (src_data_format == DataFormatEnum::int8) { | ||
const knowhere::int8* const src = reinterpret_cast<const knowhere::int8*>(src_in); | ||
for (size_t i = 0; i < nrows * dim; i++) { | ||
dst[i] = (float)(src[i + start_row * dim]); | ||
} | ||
return true; | ||
} else { | ||
// unknown | ||
|
@@ -357,21 +362,27 @@ convert_rows_from_fp32(const float* const __restrict src, void* const __restrict | |
for (size_t i = 0; i < nrows * dim; i++) { | ||
dst[i + start_row * dim] = (knowhere::fp16)src[i]; | ||
} | ||
|
||
return true; | ||
} else if (dst_data_format == DataFormatEnum::bf16) { | ||
knowhere::bf16* const dst = reinterpret_cast<knowhere::bf16*>(dst_in); | ||
for (size_t i = 0; i < nrows * dim; i++) { | ||
dst[i + start_row * dim] = (knowhere::bf16)src[i]; | ||
} | ||
|
||
return true; | ||
} else if (dst_data_format == DataFormatEnum::fp32) { | ||
knowhere::fp32* const dst = reinterpret_cast<knowhere::fp32*>(dst_in); | ||
for (size_t i = 0; i < nrows * dim; i++) { | ||
dst[i + start_row * dim] = src[i]; | ||
} | ||
|
||
return true; | ||
} else if (dst_data_format == DataFormatEnum::int8) { | ||
knowhere::int8* const dst = reinterpret_cast<knowhere::int8*>(dst_in); | ||
for (size_t i = 0; i < nrows * dim; i++) { | ||
KNOWHERE_THROW_IF_NOT_MSG(src[i] >= std::numeric_limits<knowhere::int8>::min() && | ||
src[i] <= std::numeric_limits<knowhere::int8>::max(), | ||
"convert float to int8_t overflow"); | ||
dst[i + start_row * dim] = (knowhere::int8)src[i]; | ||
} | ||
return true; | ||
} else { | ||
// unknown | ||
|
@@ -388,8 +399,9 @@ convert_ds_to_float(const DataSetPtr& src, DataFormatEnum data_format) { | |
return ConvertFromDataTypeIfNeeded<knowhere::fp16>(src); | ||
} else if (data_format == DataFormatEnum::bf16) { | ||
return ConvertFromDataTypeIfNeeded<knowhere::bf16>(src); | ||
} else if (data_format == DataFormatEnum::int8) { | ||
return ConvertFromDataTypeIfNeeded<knowhere::int8>(src); | ||
} | ||
|
||
return nullptr; | ||
} | ||
|
||
|
@@ -451,6 +463,8 @@ get_index_data_format(const faiss::Index* index) { | |
return DataFormatEnum::bf16; | ||
} else if (index_sq->sq.qtype == faiss::ScalarQuantizer::QT_fp16) { | ||
return DataFormatEnum::fp16; | ||
} else if (index_sq->sq.qtype == faiss::ScalarQuantizer::QT_8bit_direct_signed) { | ||
return DataFormatEnum::int8; | ||
} else { | ||
return std::nullopt; | ||
} | ||
|
@@ -806,49 +820,53 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { | |
if (data_format == DataFormatEnum::fp32) { | ||
// perform a direct reconstruction for fp32 data | ||
auto data = std::make_unique<float[]>(dim * rows); | ||
|
||
for (int64_t i = 0; i < rows; i++) { | ||
const int64_t id = ids[i]; | ||
assert(id >= 0 && id < index->ntotal); | ||
index_to_reconstruct_from->reconstruct(id, data.get() + i * dim); | ||
} | ||
|
||
return GenResultDataSet(rows, dim, std::move(data)); | ||
} else if (data_format == DataFormatEnum::fp16) { | ||
auto data = std::make_unique<knowhere::fp16[]>(dim * rows); | ||
|
||
// faiss produces fp32 data format, we need some other format. | ||
// Let's create a temporary fp32 buffer for this. | ||
auto tmp = std::make_unique<float[]>(dim); | ||
|
||
for (int64_t i = 0; i < rows; i++) { | ||
const int64_t id = ids[i]; | ||
assert(id >= 0 && id < index->ntotal); | ||
index_to_reconstruct_from->reconstruct(id, tmp.get()); | ||
|
||
if (!convert_rows_from_fp32(tmp.get(), data.get(), data_format, i, 1, dim)) { | ||
return expected<DataSetPtr>::Err(Status::invalid_args, "Unsupported data format"); | ||
} | ||
} | ||
|
||
return GenResultDataSet(rows, dim, std::move(data)); | ||
} else if (data_format == DataFormatEnum::bf16) { | ||
auto data = std::make_unique<knowhere::bf16[]>(dim * rows); | ||
|
||
// faiss produces fp32 data format, we need some other format. | ||
// Let's create a temporary fp32 buffer for this. | ||
auto tmp = std::make_unique<float[]>(dim); | ||
|
||
for (int64_t i = 0; i < rows; i++) { | ||
const int64_t id = ids[i]; | ||
assert(id >= 0 && id < index->ntotal); | ||
index_to_reconstruct_from->reconstruct(id, tmp.get()); | ||
|
||
if (!convert_rows_from_fp32(tmp.get(), data.get(), data_format, i, 1, dim)) { | ||
return expected<DataSetPtr>::Err(Status::invalid_args, "Unsupported data format"); | ||
} | ||
} | ||
|
||
return GenResultDataSet(rows, dim, std::move(data)); | ||
} else if (data_format == DataFormatEnum::int8) { | ||
auto data = std::make_unique<knowhere::int8[]>(dim * rows); | ||
// faiss produces fp32 data format, we need some other format. | ||
// Let's create a temporary fp32 buffer for this. | ||
auto tmp = std::make_unique<float[]>(dim); | ||
for (int64_t i = 0; i < rows; i++) { | ||
const int64_t id = ids[i]; | ||
assert(id >= 0 && id < index->ntotal); | ||
index_to_reconstruct_from->reconstruct(id, tmp.get()); | ||
if (!convert_rows_from_fp32(tmp.get(), data.get(), data_format, i, 1, dim)) { | ||
return expected<DataSetPtr>::Err(Status::invalid_args, "Unsupported data format"); | ||
} | ||
} | ||
return GenResultDataSet(rows, dim, std::move(data)); | ||
} else { | ||
return expected<DataSetPtr>::Err(Status::invalid_args, "Unsupported data format"); | ||
|
@@ -1234,13 +1252,18 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { | |
// The query data is always cloned | ||
std::unique_ptr<float[]> cur_query = std::make_unique<float[]>(dim); | ||
|
||
if (data_format == DataFormatEnum::fp32) { | ||
std::copy_n(reinterpret_cast<const float*>(data) + idx * dim, dim, cur_query.get()); | ||
} else if (data_format == DataFormatEnum::fp16 || data_format == DataFormatEnum::bf16) { | ||
convert_rows_to_fp32(data, cur_query.get(), data_format, idx, 1, dim); | ||
} else { | ||
// invalid one. Should not be triggered, bcz input parameters are validated | ||
throw; | ||
switch (data_format) { | ||
case DataFormatEnum::fp32: | ||
std::copy_n(reinterpret_cast<const float*>(data) + idx * dim, dim, cur_query.get()); | ||
break; | ||
case DataFormatEnum::fp16: | ||
case DataFormatEnum::bf16: | ||
case DataFormatEnum::int8: | ||
convert_rows_to_fp32(data, cur_query.get(), data_format, idx, 1, dim); | ||
break; | ||
default: | ||
// invalid one. Should not be triggered, bcz input parameters are validated | ||
throw; | ||
} | ||
|
||
const bool should_use_refine = (dynamic_cast<const faiss::IndexRefine*>(index.get()) != nullptr); | ||
|
@@ -1327,6 +1350,9 @@ class BaseFaissRegularIndexHNSWFlatNode : public BaseFaissRegularIndexHNSWNode { | |
} else if (data_format == DataFormatEnum::bf16) { | ||
hnsw_index = std::make_unique<faiss::IndexHNSWSQCosine>(dim, faiss::ScalarQuantizer::QT_bf16, | ||
hnsw_cfg.M.value()); | ||
} else if (data_format == DataFormatEnum::int8) { | ||
hnsw_index = std::make_unique<faiss::IndexHNSWSQCosine>( | ||
dim, faiss::ScalarQuantizer::QT_8bit_direct_signed, hnsw_cfg.M.value()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please DO confirm that you want to use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it's the requirement from Milvus, since vespa and qdrant already support Vector_Int8 now There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hi Alex, I see no obvious difference between min() and lowest(), I prefer to use min() and max() in pair. |
||
} else { | ||
LOG_KNOWHERE_ERROR_ << "Unsupported metric type: " << hnsw_cfg.metric_type.value(); | ||
return Status::invalid_metric_type; | ||
|
@@ -1340,6 +1366,9 @@ class BaseFaissRegularIndexHNSWFlatNode : public BaseFaissRegularIndexHNSWNode { | |
} else if (data_format == DataFormatEnum::bf16) { | ||
hnsw_index = std::make_unique<faiss::IndexHNSWSQ>(dim, faiss::ScalarQuantizer::QT_bf16, | ||
hnsw_cfg.M.value(), metric.value()); | ||
} else if (data_format == DataFormatEnum::int8) { | ||
hnsw_index = std::make_unique<faiss::IndexHNSWSQ>(dim, faiss::ScalarQuantizer::QT_8bit_direct_signed, | ||
hnsw_cfg.M.value(), metric.value()); | ||
} else { | ||
LOG_KNOWHERE_ERROR_ << "Unsupported metric type: " << hnsw_cfg.metric_type.value(); | ||
return Status::invalid_metric_type; | ||
|
@@ -1564,10 +1593,12 @@ namespace { | |
// a supporting function | ||
expected<faiss::ScalarQuantizer::QuantizerType> | ||
get_sq_quantizer_type(const std::string& sq_type) { | ||
std::map<std::string, faiss::ScalarQuantizer::QuantizerType> sq_types = {{"sq6", faiss::ScalarQuantizer::QT_6bit}, | ||
{"sq8", faiss::ScalarQuantizer::QT_8bit}, | ||
{"fp16", faiss::ScalarQuantizer::QT_fp16}, | ||
{"bf16", faiss::ScalarQuantizer::QT_bf16}}; | ||
std::map<std::string, faiss::ScalarQuantizer::QuantizerType> sq_types = { | ||
{"sq6", faiss::ScalarQuantizer::QT_6bit}, | ||
{"sq8", faiss::ScalarQuantizer::QT_8bit}, | ||
{"fp16", faiss::ScalarQuantizer::QT_fp16}, | ||
{"bf16", faiss::ScalarQuantizer::QT_bf16}, | ||
{"int8", faiss::ScalarQuantizer::QT_8bit_direct_signed}}; | ||
|
||
// todo: tolower | ||
auto sq_type_tolower = str_to_lower(sq_type); | ||
|
@@ -1653,6 +1684,8 @@ has_lossless_quant(const expected<faiss::ScalarQuantizer::QuantizerType>& quant_ | |
return quant == faiss::ScalarQuantizer::QuantizerType::QT_fp16; | ||
case DataFormatEnum::bf16: | ||
return quant == faiss::ScalarQuantizer::QuantizerType::QT_bf16; | ||
case DataFormatEnum::int8: | ||
return quant == faiss::ScalarQuantizer::QuantizerType::QT_8bit_direct_signed; | ||
default: | ||
return false; | ||
} | ||
|
@@ -2280,13 +2313,18 @@ KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(HNSW_DEPRECATED, | |
#else | ||
KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(HNSW, BaseFaissRegularIndexHNSWFlatNodeTemplateWithSearchFallback, | ||
knowhere::feature::MMAP | knowhere::feature::MV) | ||
KNOWHERE_SIMPLE_REGISTER_DENSE_INT8_GLOBAL(HNSW, BaseFaissRegularIndexHNSWFlatNodeTemplate, | ||
knowhere::feature::MMAP | knowhere::feature::MV) | ||
#endif | ||
|
||
KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(HNSW_SQ, BaseFaissRegularIndexHNSWSQNodeTemplate, | ||
knowhere::feature::MMAP) | ||
KNOWHERE_SIMPLE_REGISTER_DENSE_INT8_GLOBAL(HNSW_SQ, BaseFaissRegularIndexHNSWSQNodeTemplate, knowhere::feature::MMAP) | ||
KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(HNSW_PQ, BaseFaissRegularIndexHNSWPQNodeTemplate, | ||
knowhere::feature::MMAP) | ||
KNOWHERE_SIMPLE_REGISTER_DENSE_INT8_GLOBAL(HNSW_PQ, BaseFaissRegularIndexHNSWPQNodeTemplate, knowhere::feature::MMAP) | ||
KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(HNSW_PRQ, BaseFaissRegularIndexHNSWPRQNodeTemplate, | ||
knowhere::feature::MMAP) | ||
KNOWHERE_SIMPLE_REGISTER_DENSE_INT8_GLOBAL(HNSW_PRQ, BaseFaissRegularIndexHNSWPRQNodeTemplate, knowhere::feature::MMAP) | ||
|
||
} // namespace knowhere |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it is better to use
std::numeric_limilts<knowhere::int8>::lowest()
hereThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi Alex, what's the difference between min() and lowest() here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
there's no difference for this particular use case. But
lowest()
is better to use, because of the connotations withstd::numeric<float>::lowest()
(which is -1e+40, the least value) andstd::numeric<float>::min()
(which is 1e-40, the least representable positive value)