Skip to content

Commit

Permalink
Changing handle_t to device_resources everywhere (#1140)
Browse files Browse the repository at this point in the history
This is the step 1 of a two-step change to eventually have all of our public API functions use `raft::resources` instead of `raft::handle_t`. Step 1 is to use `device_resources` everywhere. This is not a breaking change since `handle_t` extends `device_resources`. Step 2 is to scrape through and use `raft::resources`, which will require adding explicit includes for each resources and using accessor functions everywhere rather than methods on the `device_resources`. No immediate rush for the second step but I'd like to at least have the documentation adjusted to use our new nomenclature everwhere.

Authors:
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Ben Frederickson (https://github.com/benfred)

URL: #1140
  • Loading branch information
cjnolet authored Jan 26, 2023
1 parent 92820d5 commit 79e1ce8
Show file tree
Hide file tree
Showing 399 changed files with 2,403 additions and 1,608 deletions.
6 changes: 3 additions & 3 deletions cpp/bench/common/benchmark.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -20,7 +20,7 @@

#include <raft/core/detail/macros.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/handle.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/cudart_utils.h>
#include <raft/interruptible.hpp>
#include <raft/random/make_blobs.cuh>
Expand Down Expand Up @@ -110,7 +110,7 @@ class fixture {
rmm::device_buffer scratch_buf_;

public:
raft::handle_t handle;
raft::device_resources handle;
rmm::cuda_stream_view stream;

fixture() : stream{handle.get_stream()}
Expand Down
6 changes: 3 additions & 3 deletions cpp/bench/distance/kernels.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -19,7 +19,7 @@

#include <common/benchmark.hpp>
#include <memory>
#include <raft/core/handle.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/distance/kernels.cuh>
#include <raft/random/rng.cuh>
Expand Down Expand Up @@ -77,7 +77,7 @@ struct GramMatrix : public fixture {
}

private:
const raft::handle_t handle;
const raft::device_resources handle;
std::unique_ptr<GramMatrixBase<T>> kernel;
GramTestParams params;

Expand Down
4 changes: 2 additions & 2 deletions cpp/bench/matrix/select_k.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

#include <common/benchmark.hpp>

#include <raft/core/handle.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/random/rng.cuh>
#include <raft/sparse/detail/utils.h>
#include <raft/util/cudart_utils.hpp>
Expand Down Expand Up @@ -51,7 +51,7 @@ struct selection : public fixture {

void run_benchmark(::benchmark::State& state) override // NOLINT
{
handle_t handle{stream};
device_resources handle{stream};
using_pool_memory_res res;
try {
std::ostringstream label_stream;
Expand Down
16 changes: 8 additions & 8 deletions cpp/bench/neighbors/knn.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -148,15 +148,15 @@ struct ivf_flat_knn {
raft::neighbors::ivf_flat::search_params search_params;
params ps;

ivf_flat_knn(const raft::handle_t& handle, const params& ps, const ValT* data) : ps(ps)
ivf_flat_knn(const raft::device_resources& handle, const params& ps, const ValT* data) : ps(ps)
{
index_params.n_lists = 4096;
index_params.metric = raft::distance::DistanceType::L2Expanded;
index.emplace(raft::neighbors::ivf_flat::build(
handle, index_params, data, IdxT(ps.n_samples), uint32_t(ps.n_dims)));
}

void search(const raft::handle_t& handle,
void search(const raft::device_resources& handle,
const ValT* search_items,
dist_t* out_dists,
IdxT* out_idxs)
Expand All @@ -176,15 +176,15 @@ struct ivf_pq_knn {
raft::neighbors::ivf_pq::search_params search_params;
params ps;

ivf_pq_knn(const raft::handle_t& handle, const params& ps, const ValT* data) : ps(ps)
ivf_pq_knn(const raft::device_resources& handle, const params& ps, const ValT* data) : ps(ps)
{
index_params.n_lists = 4096;
index_params.metric = raft::distance::DistanceType::L2Expanded;
index.emplace(raft::neighbors::ivf_pq::build(
handle, index_params, data, IdxT(ps.n_samples), uint32_t(ps.n_dims)));
}

void search(const raft::handle_t& handle,
void search(const raft::device_resources& handle,
const ValT* search_items,
dist_t* out_dists,
IdxT* out_idxs)
Expand All @@ -202,12 +202,12 @@ struct brute_force_knn {
ValT* index;
params ps;

brute_force_knn(const raft::handle_t& handle, const params& ps, const ValT* data)
brute_force_knn(const raft::device_resources& handle, const params& ps, const ValT* data)
: index(const_cast<ValT*>(data)), ps(ps)
{
}

void search(const raft::handle_t& handle,
void search(const raft::device_resources& handle,
const ValT* search_items,
dist_t* out_dists,
IdxT* out_idxs)
Expand Down Expand Up @@ -287,7 +287,7 @@ struct knn : public fixture {
std::ostringstream label_stream;
label_stream << params_ << "#" << strategy_ << "#" << scope_;
state.SetLabel(label_stream.str());
raft::handle_t handle(stream);
raft::device_resources handle(stream);
std::optional<ImplT> index;

if (scope_ == Scope::SEARCH) { // also implies TransferStrategy::NO_COPY
Expand Down
4 changes: 2 additions & 2 deletions cpp/bench/neighbors/refine.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#include <common/benchmark.hpp>

#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/neighbors/detail/refine.cuh>
#include <raft/neighbors/refine.cuh>
Expand Down Expand Up @@ -94,7 +94,7 @@ class RefineAnn : public fixture {
}

private:
raft::handle_t handle_;
raft::device_resources handle_;
RefineHelper<DataT, DistanceT, IdxT> data;
};

Expand Down
4 changes: 2 additions & 2 deletions cpp/bench/random/permute.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -50,7 +50,7 @@ struct permute : public fixture {
}

private:
raft::handle_t handle;
raft::device_resources handle;
permute_inputs params;
rmm::device_uvector<T> out, in;
rmm::device_uvector<int> perms;
Expand Down
4 changes: 2 additions & 2 deletions cpp/bench/sparse/convert_csr.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -107,7 +107,7 @@ struct bench_base : public fixture {
}

protected:
raft::handle_t handle;
raft::device_resources handle;
bench_param<index_t> params;
rmm::device_uvector<bool> adj;
rmm::device_uvector<index_t> row_ind;
Expand Down
8 changes: 4 additions & 4 deletions cpp/include/raft/cluster/detail/agglomerative.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,7 +16,7 @@

#pragma once

#include <raft/core/handle.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>

Expand Down Expand Up @@ -100,7 +100,7 @@ class UnionFind {
* @param[out] out_size cluster sizes of output
*/
template <typename value_idx, typename value_t>
void build_dendrogram_host(const handle_t& handle,
void build_dendrogram_host(raft::device_resources const& handle,
const value_idx* rows,
const value_idx* cols,
const value_t* data,
Expand Down Expand Up @@ -236,7 +236,7 @@ struct init_label_roots {
* @param n_leaves
*/
template <typename value_idx, int tpb = 256>
void extract_flattened_clusters(const raft::handle_t& handle,
void extract_flattened_clusters(raft::device_resources const& handle,
value_idx* labels,
const value_idx* children,
size_t n_clusters,
Expand Down
103 changes: 98 additions & 5 deletions cpp/include/raft/cluster/detail/connectivities.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,14 +16,15 @@

#pragma once

#include <raft/core/handle.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>

#include <raft/linalg/unary_op.cuh>
#include <rmm/device_uvector.hpp>

#include <raft/cluster/single_linkage_types.hpp>
#include <raft/distance/distance.cuh>
#include <raft/distance/distance_types.hpp>
#include <raft/sparse/convert/csr.cuh>
#include <raft/sparse/coo.hpp>
Expand All @@ -39,7 +40,7 @@ namespace raft::cluster::detail {

template <raft::cluster::LinkageDistance dist_type, typename value_idx, typename value_t>
struct distance_graph_impl {
void run(const raft::handle_t& handle,
void run(raft::device_resources const& handle,
const value_t* X,
size_t m,
size_t n,
Expand All @@ -57,7 +58,7 @@ struct distance_graph_impl {
*/
template <typename value_idx, typename value_t>
struct distance_graph_impl<raft::cluster::LinkageDistance::KNN_GRAPH, value_idx, value_t> {
void run(const raft::handle_t& handle,
void run(raft::device_resources const& handle,
const value_t* X,
size_t m,
size_t n,
Expand Down Expand Up @@ -103,6 +104,98 @@ struct distance_graph_impl<raft::cluster::LinkageDistance::KNN_GRAPH, value_idx,
}
};

template <typename value_idx>
__global__ void fill_indices2(value_idx* indices, size_t m, size_t nnz)
{
value_idx tid = (blockIdx.x * blockDim.x) + threadIdx.x;
if (tid >= nnz) return;
value_idx v = tid % m;
indices[tid] = v;
}

/**
* Compute connected CSR of pairwise distances
* @tparam value_idx
* @tparam value_t
* @param handle
* @param X
* @param m
* @param n
* @param metric
* @param[out] indptr
* @param[out] indices
* @param[out] data
*/
template <typename value_idx, typename value_t>
void pairwise_distances(const raft::device_resources& handle,
const value_t* X,
size_t m,
size_t n,
raft::distance::DistanceType metric,
value_idx* indptr,
value_idx* indices,
value_t* data)
{
auto stream = handle.get_stream();
auto exec_policy = handle.get_thrust_policy();

value_idx nnz = m * m;

value_idx blocks = raft::ceildiv(nnz, (value_idx)256);
fill_indices2<value_idx><<<blocks, 256, 0, stream>>>(indices, m, nnz);

thrust::sequence(exec_policy, indptr, indptr + m, 0, (int)m);

raft::update_device(indptr + m, &nnz, 1, stream);

// TODO: It would ultimately be nice if the MST could accept
// dense inputs directly so we don't need to double the memory
// usage to hand it a sparse array here.
distance::pairwise_distance<value_t, value_idx>(handle, X, X, data, m, m, n, metric);
// self-loops get max distance
auto transform_in =
thrust::make_zip_iterator(thrust::make_tuple(thrust::make_counting_iterator(0), data));

thrust::transform(exec_policy,
transform_in,
transform_in + nnz,
data,
[=] __device__(const thrust::tuple<value_idx, value_t>& tup) {
value_idx idx = thrust::get<0>(tup);
bool self_loop = idx % m == idx / m;
return (self_loop * std::numeric_limits<value_t>::max()) +
(!self_loop * thrust::get<1>(tup));
});
}

/**
* Connectivities specialization for pairwise distances
* @tparam value_idx
* @tparam value_t
*/
template <typename value_idx, typename value_t>
struct distance_graph_impl<raft::cluster::LinkageDistance::PAIRWISE, value_idx, value_t> {
void run(const raft::device_resources& handle,
const value_t* X,
size_t m,
size_t n,
raft::distance::DistanceType metric,
rmm::device_uvector<value_idx>& indptr,
rmm::device_uvector<value_idx>& indices,
rmm::device_uvector<value_t>& data,
int c)
{
auto stream = handle.get_stream();

size_t nnz = m * m;

indices.resize(nnz, stream);
data.resize(nnz, stream);

pairwise_distances(handle, X, m, n, metric, indptr.data(), indices.data(), data.data());
}
};

/**
* Returns a CSR connectivities graph based on the given linkage distance.
* @tparam value_idx
Expand All @@ -120,7 +213,7 @@ struct distance_graph_impl<raft::cluster::LinkageDistance::KNN_GRAPH, value_idx,
* which will guarantee k <= log(n) + c
*/
template <typename value_idx, typename value_t, raft::cluster::LinkageDistance dist_type>
void get_distance_graph(const raft::handle_t& handle,
void get_distance_graph(raft::device_resources const& handle,
const value_t* X,
size_t m,
size_t n,
Expand Down
Loading

0 comments on commit 79e1ce8

Please sign in to comment.