Skip to content

Commit

Permalink
Update raft::sparse::distance::pairwise_distance to new API (#5428)
Browse files Browse the repository at this point in the history
  • Loading branch information
divyegala authored Jul 5, 2023
1 parent 00bdad3 commit a84df5b
Showing 1 changed file with 10 additions and 16 deletions.
26 changes: 10 additions & 16 deletions cpp/src/metrics/pairwise_distance.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -31,7 +31,6 @@
#include <raft/core/handle.hpp>
#include <raft/distance/distance.cuh>
#include <raft/distance/specializations.cuh>
#include <raft/sparse/distance/common.h>
#include <raft/sparse/distance/distance.cuh>

namespace ML {
Expand Down Expand Up @@ -164,23 +163,18 @@ void pairwiseDistance_sparse(const raft::handle_t& handle,
raft::distance::DistanceType metric,
float metric_arg)
{
raft::sparse::distance::distances_config_t<value_idx, value_t> dist_config(handle);
auto out = raft::make_device_matrix_view<value_t, value_idx>(dist, y_nrows, x_nrows);

dist_config.b_nrows = x_nrows;
dist_config.b_ncols = n_cols;
dist_config.b_nnz = x_nnz;
dist_config.b_indptr = x_indptr;
dist_config.b_indices = x_indices;
dist_config.b_data = x;
auto x_structure = raft::make_device_compressed_structure_view<value_idx, value_idx, value_idx>(
x_indptr, x_indices, x_nrows, n_cols, x_nnz);
auto x_csr_view = raft::make_device_csr_matrix_view<const value_t>(x, x_structure);

dist_config.a_nrows = y_nrows;
dist_config.a_ncols = n_cols;
dist_config.a_nnz = y_nnz;
dist_config.a_indptr = y_indptr;
dist_config.a_indices = y_indices;
dist_config.a_data = y;
auto y_structure = raft::make_device_compressed_structure_view<value_idx, value_idx, value_idx>(
y_indptr, y_indices, y_nrows, n_cols, y_nnz);
auto y_csr_view = raft::make_device_csr_matrix_view<const value_t>(y, y_structure);

raft::sparse::distance::pairwiseDistance(dist, dist_config, metric, metric_arg);
raft::sparse::distance::pairwise_distance(
handle, y_csr_view, x_csr_view, out, metric, metric_arg);
}

void pairwiseDistance_sparse(const raft::handle_t& handle,
Expand Down

0 comments on commit a84df5b

Please sign in to comment.