Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

faiss_hnsw support INT8 #991

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/knowhere/comp/index_param.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ enum VecType {
VECTOR_FLOAT16 = 102,
VECTOR_BFLOAT16 = 103,
VECTOR_SPARSE_FLOAT = 104,
VECTOR_INT8 = 105,
}; // keep the same value as milvus proto define

} // namespace knowhere
6 changes: 4 additions & 2 deletions include/knowhere/feature.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ constexpr uint64_t FP16 = 1UL << 2;
constexpr uint64_t BF16 = 1UL << 3;
// vector datatype support : sparse_float32
constexpr uint64_t SPARSE_FLOAT32 = 1UL << 4;
// vector datatype support : int8
constexpr uint64_t INT8 = 1UL << 5;

// This flag indicates that there is no need to create any index structure (build stage can be skipped)
constexpr uint64_t NO_TRAIN = 1UL << 16;
Expand All @@ -45,8 +47,8 @@ constexpr uint64_t DISK = 1UL << 21;

constexpr uint64_t NONE = 0UL;

constexpr uint64_t ALL_TYPE = BINARY | FLOAT32 | FP16 | BF16 | SPARSE_FLOAT32;
constexpr uint64_t ALL_DENSE_TYPE = BINARY | FLOAT32 | FP16 | BF16;
constexpr uint64_t ALL_TYPE = BINARY | FLOAT32 | FP16 | BF16 | SPARSE_FLOAT32 | INT8;
constexpr uint64_t ALL_DENSE_TYPE = BINARY | FLOAT32 | FP16 | BF16 | INT8;
constexpr uint64_t ALL_DENSE_FLOAT_TYPE = FLOAT32 | FP16 | BF16;

constexpr uint64_t NO_TRAIN_INDEX = NO_TRAIN;
Expand Down
3 changes: 3 additions & 0 deletions include/knowhere/index/index_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ class IndexFactory {
// register vector index supporting binary data type
#define KNOWHERE_SIMPLE_REGISTER_DENSE_BIN_GLOBAL(name, index_node, features, ...) \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, bin1, (features | knowhere::feature::BINARY), ##__VA_ARGS__);
// Register a vector index supporting the int8 data type.
// Expands to KNOWHERE_SIMPLE_REGISTER_GLOBAL with `int8` as the data type and with
// knowhere::feature::INT8 OR-ed into `features`; any extra arguments are forwarded unchanged.
#define KNOWHERE_SIMPLE_REGISTER_DENSE_INT8_GLOBAL(name, index_node, features, ...) \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, int8, (features | knowhere::feature::INT8), ##__VA_ARGS__);

// register vector index supporting ALL_DENSE_FLOAT_TYPE(float32, bf16, fp16) data types
#define KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(name, index_node, features, ...) \
Expand Down
4 changes: 4 additions & 0 deletions include/knowhere/index/index_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,18 +61,22 @@ static std::set<std::pair<std::string, VecType>> legal_knowhere_index = {
{IndexEnum::INDEX_HNSW, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_HNSW, VecType::VECTOR_FLOAT16},
{IndexEnum::INDEX_HNSW, VecType::VECTOR_BFLOAT16},
{IndexEnum::INDEX_HNSW, VecType::VECTOR_INT8},

{IndexEnum::INDEX_HNSW_SQ, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_HNSW_SQ, VecType::VECTOR_FLOAT16},
{IndexEnum::INDEX_HNSW_SQ, VecType::VECTOR_BFLOAT16},
{IndexEnum::INDEX_HNSW_SQ, VecType::VECTOR_INT8},

{IndexEnum::INDEX_HNSW_PQ, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_HNSW_PQ, VecType::VECTOR_FLOAT16},
{IndexEnum::INDEX_HNSW_PQ, VecType::VECTOR_BFLOAT16},
{IndexEnum::INDEX_HNSW_PQ, VecType::VECTOR_INT8},

{IndexEnum::INDEX_HNSW_PRQ, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_HNSW_PRQ, VecType::VECTOR_FLOAT16},
{IndexEnum::INDEX_HNSW_PRQ, VecType::VECTOR_BFLOAT16},
{IndexEnum::INDEX_HNSW_PRQ, VecType::VECTOR_INT8},

// diskann
{IndexEnum::INDEX_DISKANN, VecType::VECTOR_FLOAT},
Expand Down
11 changes: 10 additions & 1 deletion include/knowhere/operands.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ fp32_to_bits(const float& f) {

namespace knowhere {
using fp32 = float;
using int8 = int8_t;
using bin1 = uint8_t;

struct fp16 {
Expand Down Expand Up @@ -161,13 +162,16 @@ typeCheck(uint64_t features) {
if constexpr (std::is_same_v<T, fp32>) {
return (features & knowhere::feature::FLOAT32) || (features & knowhere::feature::SPARSE_FLOAT32);
}
if constexpr (std::is_same_v<T, int8>) {
return features & knowhere::feature::INT8;
}
return false;
}

template <typename InType, typename... Types>
using TypeMatch = std::bool_constant<(... | std::is_same_v<InType, Types>)>;
template <typename InType>
using KnowhereDataTypeCheck = TypeMatch<InType, bin1, fp16, fp32, bf16>;
using KnowhereDataTypeCheck = TypeMatch<InType, bin1, fp16, fp32, bf16, int8>;
template <typename InType>
using KnowhereFloatTypeCheck = TypeMatch<InType, fp16, fp32, bf16>;
template <typename InType>
Expand All @@ -187,5 +191,10 @@ template <>
struct MockData<knowhere::bf16> {
using type = knowhere::fp32;
};

// MockData specialization for int8: its `type` alias is fp32,
// the same mapping the bf16 specialization uses.
template <>
struct MockData<knowhere::int8> {
using type = knowhere::fp32;
};
} // namespace knowhere
#endif /* OPERANDS_H */
2 changes: 2 additions & 0 deletions include/knowhere/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ GetKey(const std::string& name) {
return name + std::string("_bf16");
} else if (std::is_same_v<DataType, bin1>) {
return name + std::string("_bin1");
} else if (std::is_same_v<DataType, int8>) {
return name + std::string("_int8");
}
}

Expand Down
100 changes: 69 additions & 31 deletions src/index/hnsw/faiss_hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ class BaseFaissRegularIndexNode : public BaseFaissIndexNode {
};

//
enum class DataFormatEnum { fp32, fp16, bf16 };
enum class DataFormatEnum { fp32, fp16, bf16, int8 };

template <typename T>
struct DataType2EnumHelper {};
Expand All @@ -309,14 +309,16 @@ template <>
struct DataType2EnumHelper<knowhere::bf16> {
static constexpr DataFormatEnum value = DataFormatEnum::bf16;
};
// Maps the knowhere::int8 element type to its runtime tag, DataFormatEnum::int8.
template <>
struct DataType2EnumHelper<knowhere::int8> {
static constexpr DataFormatEnum value = DataFormatEnum::int8;
};

// Shorthand for DataType2EnumHelper<T>::value: the DataFormatEnum tag of element type T.
// Only types with a DataType2EnumHelper specialization provide `value`.
template <typename T>
static constexpr DataFormatEnum datatype_v = DataType2EnumHelper<T>::value;

//
namespace {

//
bool
convert_rows_to_fp32(const void* const __restrict src_in, float* const __restrict dst,
const DataFormatEnum src_data_format, const size_t start_row, const size_t nrows,
Expand All @@ -326,21 +328,24 @@ convert_rows_to_fp32(const void* const __restrict src_in, float* const __restric
for (size_t i = 0; i < nrows * dim; i++) {
dst[i] = (float)(src[i + start_row * dim]);
}

return true;
} else if (src_data_format == DataFormatEnum::bf16) {
const knowhere::bf16* const src = reinterpret_cast<const knowhere::bf16*>(src_in);
for (size_t i = 0; i < nrows * dim; i++) {
dst[i] = (float)(src[i + start_row * dim]);
}

return true;
} else if (src_data_format == DataFormatEnum::fp32) {
const knowhere::fp32* const src = reinterpret_cast<const knowhere::fp32*>(src_in);
for (size_t i = 0; i < nrows * dim; i++) {
dst[i] = src[i + start_row * dim];
}

return true;
} else if (src_data_format == DataFormatEnum::int8) {
const knowhere::int8* const src = reinterpret_cast<const knowhere::int8*>(src_in);
for (size_t i = 0; i < nrows * dim; i++) {
dst[i] = (float)(src[i + start_row * dim]);
}
return true;
} else {
// unknown
Expand All @@ -357,21 +362,27 @@ convert_rows_from_fp32(const float* const __restrict src, void* const __restrict
for (size_t i = 0; i < nrows * dim; i++) {
dst[i + start_row * dim] = (knowhere::fp16)src[i];
}

return true;
} else if (dst_data_format == DataFormatEnum::bf16) {
knowhere::bf16* const dst = reinterpret_cast<knowhere::bf16*>(dst_in);
for (size_t i = 0; i < nrows * dim; i++) {
dst[i + start_row * dim] = (knowhere::bf16)src[i];
}

return true;
} else if (dst_data_format == DataFormatEnum::fp32) {
knowhere::fp32* const dst = reinterpret_cast<knowhere::fp32*>(dst_in);
for (size_t i = 0; i < nrows * dim; i++) {
dst[i + start_row * dim] = src[i];
}

return true;
} else if (dst_data_format == DataFormatEnum::int8) {
knowhere::int8* const dst = reinterpret_cast<knowhere::int8*>(dst_in);
for (size_t i = 0; i < nrows * dim; i++) {
KNOWHERE_THROW_IF_NOT_MSG(src[i] >= std::numeric_limits<knowhere::int8>::min() &&
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is better to use std::numeric_limits<knowhere::int8>::lowest() here

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi Alex, what's the difference between min() and lowest() here?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there's no difference for this particular use case. But lowest() is better to use, because of the distinction between std::numeric_limits<float>::lowest() (which is about -3.4e+38, the least finite value) and std::numeric_limits<float>::min() (which is about 1.18e-38, the least positive normalized value)

src[i] <= std::numeric_limits<knowhere::int8>::max(),
"convert float to int8_t overflow");
dst[i + start_row * dim] = (knowhere::int8)src[i];
}
return true;
} else {
// unknown
Expand All @@ -388,8 +399,9 @@ convert_ds_to_float(const DataSetPtr& src, DataFormatEnum data_format) {
return ConvertFromDataTypeIfNeeded<knowhere::fp16>(src);
} else if (data_format == DataFormatEnum::bf16) {
return ConvertFromDataTypeIfNeeded<knowhere::bf16>(src);
} else if (data_format == DataFormatEnum::int8) {
return ConvertFromDataTypeIfNeeded<knowhere::int8>(src);
}

return nullptr;
}

Expand Down Expand Up @@ -451,6 +463,8 @@ get_index_data_format(const faiss::Index* index) {
return DataFormatEnum::bf16;
} else if (index_sq->sq.qtype == faiss::ScalarQuantizer::QT_fp16) {
return DataFormatEnum::fp16;
} else if (index_sq->sq.qtype == faiss::ScalarQuantizer::QT_8bit_direct_signed) {
return DataFormatEnum::int8;
} else {
return std::nullopt;
}
Expand Down Expand Up @@ -806,49 +820,53 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode {
if (data_format == DataFormatEnum::fp32) {
// perform a direct reconstruction for fp32 data
auto data = std::make_unique<float[]>(dim * rows);

for (int64_t i = 0; i < rows; i++) {
const int64_t id = ids[i];
assert(id >= 0 && id < index->ntotal);
index_to_reconstruct_from->reconstruct(id, data.get() + i * dim);
}

return GenResultDataSet(rows, dim, std::move(data));
} else if (data_format == DataFormatEnum::fp16) {
auto data = std::make_unique<knowhere::fp16[]>(dim * rows);

// faiss produces fp32 data format, we need some other format.
// Let's create a temporary fp32 buffer for this.
auto tmp = std::make_unique<float[]>(dim);

for (int64_t i = 0; i < rows; i++) {
const int64_t id = ids[i];
assert(id >= 0 && id < index->ntotal);
index_to_reconstruct_from->reconstruct(id, tmp.get());

if (!convert_rows_from_fp32(tmp.get(), data.get(), data_format, i, 1, dim)) {
return expected<DataSetPtr>::Err(Status::invalid_args, "Unsupported data format");
}
}

return GenResultDataSet(rows, dim, std::move(data));
} else if (data_format == DataFormatEnum::bf16) {
auto data = std::make_unique<knowhere::bf16[]>(dim * rows);

// faiss produces fp32 data format, we need some other format.
// Let's create a temporary fp32 buffer for this.
auto tmp = std::make_unique<float[]>(dim);

for (int64_t i = 0; i < rows; i++) {
const int64_t id = ids[i];
assert(id >= 0 && id < index->ntotal);
index_to_reconstruct_from->reconstruct(id, tmp.get());

if (!convert_rows_from_fp32(tmp.get(), data.get(), data_format, i, 1, dim)) {
return expected<DataSetPtr>::Err(Status::invalid_args, "Unsupported data format");
}
}

return GenResultDataSet(rows, dim, std::move(data));
} else if (data_format == DataFormatEnum::int8) {
auto data = std::make_unique<knowhere::int8[]>(dim * rows);
// faiss produces fp32 data format, we need some other format.
// Let's create a temporary fp32 buffer for this.
auto tmp = std::make_unique<float[]>(dim);
for (int64_t i = 0; i < rows; i++) {
const int64_t id = ids[i];
assert(id >= 0 && id < index->ntotal);
index_to_reconstruct_from->reconstruct(id, tmp.get());
if (!convert_rows_from_fp32(tmp.get(), data.get(), data_format, i, 1, dim)) {
return expected<DataSetPtr>::Err(Status::invalid_args, "Unsupported data format");
}
}
return GenResultDataSet(rows, dim, std::move(data));
} else {
return expected<DataSetPtr>::Err(Status::invalid_args, "Unsupported data format");
Expand Down Expand Up @@ -1234,13 +1252,18 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode {
// The query data is always cloned
std::unique_ptr<float[]> cur_query = std::make_unique<float[]>(dim);

if (data_format == DataFormatEnum::fp32) {
std::copy_n(reinterpret_cast<const float*>(data) + idx * dim, dim, cur_query.get());
} else if (data_format == DataFormatEnum::fp16 || data_format == DataFormatEnum::bf16) {
convert_rows_to_fp32(data, cur_query.get(), data_format, idx, 1, dim);
} else {
// invalid one. Should not be triggered, bcz input parameters are validated
throw;
switch (data_format) {
case DataFormatEnum::fp32:
std::copy_n(reinterpret_cast<const float*>(data) + idx * dim, dim, cur_query.get());
break;
case DataFormatEnum::fp16:
case DataFormatEnum::bf16:
case DataFormatEnum::int8:
convert_rows_to_fp32(data, cur_query.get(), data_format, idx, 1, dim);
break;
default:
// invalid one. Should not be triggered, bcz input parameters are validated
throw;
}

const bool should_use_refine = (dynamic_cast<const faiss::IndexRefine*>(index.get()) != nullptr);
Expand Down Expand Up @@ -1327,6 +1350,9 @@ class BaseFaissRegularIndexHNSWFlatNode : public BaseFaissRegularIndexHNSWNode {
} else if (data_format == DataFormatEnum::bf16) {
hnsw_index = std::make_unique<faiss::IndexHNSWSQCosine>(dim, faiss::ScalarQuantizer::QT_bf16,
hnsw_cfg.M.value());
} else if (data_format == DataFormatEnum::int8) {
hnsw_index = std::make_unique<faiss::IndexHNSWSQCosine>(
dim, faiss::ScalarQuantizer::QT_8bit_direct_signed, hnsw_cfg.M.value());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please DO confirm that you want to use QT_8bit_direct_signed here, because the use case is not clear to me. Basically, I can imagine a use case that works with the input data of [0..255] range (QT_8bit_direct), or the traditional QT_8bit that remaps input float values into [0..255] range, but what is the use case for the input data of [-128..127] range? Or is it just the requirement from Milvus?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's the requirement from Milvus, since vespa and qdrant already support Vector_Int8 now

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi Alex, I see no obvious difference between min() and lowest(), I prefer to use min() and max() in pair.

} else {
LOG_KNOWHERE_ERROR_ << "Unsupported metric type: " << hnsw_cfg.metric_type.value();
return Status::invalid_metric_type;
Expand All @@ -1340,6 +1366,9 @@ class BaseFaissRegularIndexHNSWFlatNode : public BaseFaissRegularIndexHNSWNode {
} else if (data_format == DataFormatEnum::bf16) {
hnsw_index = std::make_unique<faiss::IndexHNSWSQ>(dim, faiss::ScalarQuantizer::QT_bf16,
hnsw_cfg.M.value(), metric.value());
} else if (data_format == DataFormatEnum::int8) {
hnsw_index = std::make_unique<faiss::IndexHNSWSQ>(dim, faiss::ScalarQuantizer::QT_8bit_direct_signed,
hnsw_cfg.M.value(), metric.value());
} else {
LOG_KNOWHERE_ERROR_ << "Unsupported metric type: " << hnsw_cfg.metric_type.value();
return Status::invalid_metric_type;
Expand Down Expand Up @@ -1564,10 +1593,12 @@ namespace {
// a supporting function
expected<faiss::ScalarQuantizer::QuantizerType>
get_sq_quantizer_type(const std::string& sq_type) {
std::map<std::string, faiss::ScalarQuantizer::QuantizerType> sq_types = {{"sq6", faiss::ScalarQuantizer::QT_6bit},
{"sq8", faiss::ScalarQuantizer::QT_8bit},
{"fp16", faiss::ScalarQuantizer::QT_fp16},
{"bf16", faiss::ScalarQuantizer::QT_bf16}};
std::map<std::string, faiss::ScalarQuantizer::QuantizerType> sq_types = {
{"sq6", faiss::ScalarQuantizer::QT_6bit},
{"sq8", faiss::ScalarQuantizer::QT_8bit},
{"fp16", faiss::ScalarQuantizer::QT_fp16},
{"bf16", faiss::ScalarQuantizer::QT_bf16},
{"int8", faiss::ScalarQuantizer::QT_8bit_direct_signed}};

// todo: tolower
auto sq_type_tolower = str_to_lower(sq_type);
Expand Down Expand Up @@ -1653,6 +1684,8 @@ has_lossless_quant(const expected<faiss::ScalarQuantizer::QuantizerType>& quant_
return quant == faiss::ScalarQuantizer::QuantizerType::QT_fp16;
case DataFormatEnum::bf16:
return quant == faiss::ScalarQuantizer::QuantizerType::QT_bf16;
case DataFormatEnum::int8:
return quant == faiss::ScalarQuantizer::QuantizerType::QT_8bit_direct_signed;
default:
return false;
}
Expand Down Expand Up @@ -2280,13 +2313,18 @@ KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(HNSW_DEPRECATED,
#else
KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(HNSW, BaseFaissRegularIndexHNSWFlatNodeTemplateWithSearchFallback,
knowhere::feature::MMAP | knowhere::feature::MV)
KNOWHERE_SIMPLE_REGISTER_DENSE_INT8_GLOBAL(HNSW, BaseFaissRegularIndexHNSWFlatNodeTemplate,
knowhere::feature::MMAP | knowhere::feature::MV)
#endif

KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(HNSW_SQ, BaseFaissRegularIndexHNSWSQNodeTemplate,
knowhere::feature::MMAP)
KNOWHERE_SIMPLE_REGISTER_DENSE_INT8_GLOBAL(HNSW_SQ, BaseFaissRegularIndexHNSWSQNodeTemplate, knowhere::feature::MMAP)
KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(HNSW_PQ, BaseFaissRegularIndexHNSWPQNodeTemplate,
knowhere::feature::MMAP)
KNOWHERE_SIMPLE_REGISTER_DENSE_INT8_GLOBAL(HNSW_PQ, BaseFaissRegularIndexHNSWPQNodeTemplate, knowhere::feature::MMAP)
KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(HNSW_PRQ, BaseFaissRegularIndexHNSWPRQNodeTemplate,
knowhere::feature::MMAP)
KNOWHERE_SIMPLE_REGISTER_DENSE_INT8_GLOBAL(HNSW_PRQ, BaseFaissRegularIndexHNSWPRQNodeTemplate, knowhere::feature::MMAP)

} // namespace knowhere
6 changes: 6 additions & 0 deletions src/index/index_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ template knowhere::expected<knowhere::Index<knowhere::IndexNode>>
knowhere::IndexFactory::Create<knowhere::fp16>(const std::string&, const int32_t&, const Object&);
template knowhere::expected<knowhere::Index<knowhere::IndexNode>>
knowhere::IndexFactory::Create<knowhere::bf16>(const std::string&, const int32_t&, const Object&);
template knowhere::expected<knowhere::Index<knowhere::IndexNode>>
knowhere::IndexFactory::Create<knowhere::int8>(const std::string&, const int32_t&, const Object&);
template const knowhere::IndexFactory&
knowhere::IndexFactory::Register<knowhere::fp32>(
const std::string&, std::function<knowhere::Index<knowhere::IndexNode>(const int32_t&, const Object&)>,
Expand All @@ -153,3 +155,7 @@ template const knowhere::IndexFactory&
knowhere::IndexFactory::Register<knowhere::bf16>(
const std::string&, std::function<knowhere::Index<knowhere::IndexNode>(const int32_t&, const Object&)>,
const uint64_t);
template const knowhere::IndexFactory&
knowhere::IndexFactory::Register<knowhere::int8>(
const std::string&, std::function<knowhere::Index<knowhere::IndexNode>(const int32_t&, const Object&)>,
const uint64_t);
1 change: 1 addition & 0 deletions src/index/index_static.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,5 +152,6 @@ template class IndexStaticFaced<knowhere::fp32>;
template class IndexStaticFaced<knowhere::fp16>;
template class IndexStaticFaced<knowhere::bf16>;
template class IndexStaticFaced<knowhere::bin1>;
template class IndexStaticFaced<knowhere::int8>;

} // namespace knowhere
Loading
Loading