Skip to content

Commit

Permalink
enhance: remove extra memory copy of bf16/bf16 brustforce search
Browse files Browse the repository at this point in the history
Signed-off-by: cqy123456 <qianya.cheng@zilliz.com>
  • Loading branch information
cqy123456 committed Dec 26, 2024
1 parent 85f462b commit 11335af
Showing 1 changed file with 6 additions and 24 deletions.
30 changes: 6 additions & 24 deletions internal/core/src/query/SearchBruteForce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,18 +82,6 @@ PrepareBFDataSet(const dataset::SearchDataset& query_ds,
if (data_type == DataType::VECTOR_SPARSE_FLOAT) {
base_dataset->SetIsSparse(true);
query_dataset->SetIsSparse(true);
} else if (data_type == DataType::VECTOR_BFLOAT16) {
//todo: if knowhere support real fp16/bf16 bf, remove convert
base_dataset =
knowhere::ConvertFromDataTypeIfNeeded<bfloat16>(base_dataset);
query_dataset =
knowhere::ConvertFromDataTypeIfNeeded<bfloat16>(query_dataset);
} else if (data_type == DataType::VECTOR_FLOAT16) {
//todo: if knowhere support real fp16/bf16 bf, remove convert
base_dataset =
knowhere::ConvertFromDataTypeIfNeeded<float16>(base_dataset);
query_dataset =
knowhere::ConvertFromDataTypeIfNeeded<float16>(query_dataset);
}
base_dataset->SetTensorBeginId(raw_ds.begin_id);
return std::make_pair(query_dataset, base_dataset);
Expand Down Expand Up @@ -133,12 +121,10 @@ BruteForceSearch(const dataset::SearchDataset& query_ds,
res = knowhere::BruteForce::RangeSearch<float>(
base_dataset, query_dataset, search_cfg, bitset);
} else if (data_type == DataType::VECTOR_FLOAT16) {
//todo: if knowhere support real fp16/bf16 bf, change it
res = knowhere::BruteForce::RangeSearch<float>(
res = knowhere::BruteForce::RangeSearch<float16>(
base_dataset, query_dataset, search_cfg, bitset);
} else if (data_type == DataType::VECTOR_BFLOAT16) {
//todo: if knowhere support real fp16/bf16 bf, change it
res = knowhere::BruteForce::RangeSearch<float>(
res = knowhere::BruteForce::RangeSearch<bfloat16>(
base_dataset, query_dataset, search_cfg, bitset);
} else if (data_type == DataType::VECTOR_BINARY) {
res = knowhere::BruteForce::RangeSearch<uint8_t>(
Expand Down Expand Up @@ -178,17 +164,15 @@ BruteForceSearch(const dataset::SearchDataset& query_ds,
search_cfg,
bitset);
} else if (data_type == DataType::VECTOR_FLOAT16) {
//todo: if knowhere support real fp16/bf16 bf, change it
stat = knowhere::BruteForce::SearchWithBuf<float>(
stat = knowhere::BruteForce::SearchWithBuf<float16>(
base_dataset,
query_dataset,
sub_result.mutable_seg_offsets().data(),
sub_result.mutable_distances().data(),
search_cfg,
bitset);
} else if (data_type == DataType::VECTOR_BFLOAT16) {
//todo: if knowhere support real fp16/bf16 bf, change it
stat = knowhere::BruteForce::SearchWithBuf<float>(
stat = knowhere::BruteForce::SearchWithBuf<bfloat16>(
base_dataset,
query_dataset,
sub_result.mutable_seg_offsets().data(),
Expand Down Expand Up @@ -238,13 +222,11 @@ DispatchBruteForceIteratorByDataType(const knowhere::DataSetPtr& base_dataset,
base_dataset, query_dataset, config, bitset);
break;
case DataType::VECTOR_FLOAT16:
//todo: if knowhere support real fp16/bf16 bf, change it
return knowhere::BruteForce::AnnIterator<float>(
return knowhere::BruteForce::AnnIterator<float16>(
base_dataset, query_dataset, config, bitset);
break;
case DataType::VECTOR_BFLOAT16:
//todo: if knowhere support real fp16/bf16 bf, change it
return knowhere::BruteForce::AnnIterator<float>(
return knowhere::BruteForce::AnnIterator<bfloat16>(
base_dataset, query_dataset, config, bitset);
break;
case DataType::VECTOR_SPARSE_FLOAT:
Expand Down

0 comments on commit 11335af

Please sign in to comment.