Skip to content

Commit

Permalink
re-add MetricProcessor code
Browse files Browse the repository at this point in the history
We don't currently have cosine distance for ivf-pq (see rapidsai/cuvs#346)
and we also don't have correlation distance support at all. re-add the metricprocessor
code to handle this
  • Loading branch information
benfred committed Sep 27, 2024
1 parent 590865d commit c7d1b0e
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 1 deletion.
2 changes: 2 additions & 0 deletions cpp/include/cuml/neighbors/knn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include <raft/distance/distance_types.hpp>
#include <raft/spatial/knn/ball_cover_types.hpp>
#include <raft/spatial/knn/detail/processing.hpp> // MetricProcessor

#include <cuvs/neighbors/ivf_flat.hpp>
#include <cuvs/neighbors/ivf_pq.hpp>
Expand Down Expand Up @@ -81,6 +82,7 @@ struct knnIndex {
raft::distance::DistanceType metric;
float metricArg;
int nprobe;
std::unique_ptr<raft::spatial::knn::MetricProcessor<float>> metric_processor;

std::unique_ptr<cuvs::neighbors::ivf_flat::index<float, int64_t>> ivf_flat;
std::unique_ptr<cuvs::neighbors::ivf_pq::index<int64_t>> ivf_pq;
Expand Down
37 changes: 36 additions & 1 deletion cpp/src/knn/knn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,18 @@ void approx_knn_build_index(raft::handle_t& handle,

auto ivf_ft_pams = dynamic_cast<IVFFlatParam*>(params);
auto ivf_pq_pams = dynamic_cast<IVFPQParam*>(params);
auto index_view = raft::make_device_matrix_view<const float, int64_t>(index_array, n, D);

index->metric_processor = raft::spatial::knn::create_processor<float>(
metric, n, D, 0, false, raft::resource::get_cuda_stream(handle));
// For cosine/correlation distance, the metric processor translates distance
// to inner product via pre/post processing - pass the translated metric to
// ANN index
if (metric == raft::distance::DistanceType::CosineExpanded ||
metric == raft::distance::DistanceType::CorrelationExpanded) {
metric = index->metric = raft::distance::DistanceType::InnerProduct;
}
index->metric_processor->preprocess(index_array);
auto index_view = raft::make_device_matrix_view<const float, int64_t>(index_array, n, D);

if (ivf_ft_pams) {
index->nprobe = ivf_ft_pams->nprobe;
Expand All @@ -216,6 +227,8 @@ void approx_knn_build_index(raft::handle_t& handle,
} else {
RAFT_FAIL("Unrecognized index type.");
}

index->metric_processor->revert(index_array);
}

void approx_knn_search(raft::handle_t& handle,
Expand All @@ -226,6 +239,9 @@ void approx_knn_search(raft::handle_t& handle,
float* query_array,
int n)
{
index->metric_processor->preprocess(query_array);
index->metric_processor->set_num_queries(k);

auto indices_view = raft::make_device_matrix_view<int64_t, int64_t>(indices, n, k);
auto distances_view = raft::make_device_matrix_view<float, int64_t>(distances, n, k);

Expand All @@ -248,6 +264,25 @@ void approx_knn_search(raft::handle_t& handle,
} else {
RAFT_FAIL("The model is not trained");
}

index->metric_processor->revert(query_array);

// perform post-processing to show the real distances
if (index->metric == raft::distance::DistanceType::L2SqrtExpanded ||
index->metric == raft::distance::DistanceType::L2SqrtUnexpanded ||
index->metric == raft::distance::DistanceType::LpUnexpanded) {
/**
* post-processing
*/
float p = 0.5; // standard l2
if (index->metric == raft::distance::DistanceType::LpUnexpanded) p = 1.0 / index->metricArg;
raft::linalg::unaryOp<float>(distances,
distances,
n * k,
raft::pow_const_op<float>(p),
raft::resource::get_cuda_stream(handle));
}
index->metric_processor->postprocess(distances);
}

void knn_classify(raft::handle_t& handle,
Expand Down

0 comments on commit c7d1b0e

Please sign in to comment.