Skip to content

Commit

Permalink
select_k: Replace specialization by split header
Browse files Browse the repository at this point in the history
  • Loading branch information
ahendriksen committed Apr 19, 2023
1 parent 722389d commit bc7bea4
Show file tree
Hide file tree
Showing 14 changed files with 300 additions and 173 deletions.
6 changes: 6 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,12 @@ if(RAFT_COMPILE_LIBRARY)
src/cluster/update_centroids_double.cu
src/cluster/cluster_cost_float.cu
src/cluster/cluster_cost_double.cu
src/matrix/detail/select_k_double_int64_t.cu
src/matrix/detail/select_k_double_uint32_t.cu
src/matrix/detail/select_k_float_int64_t.cu
src/matrix/detail/select_k_float_uint32_t.cu
src/matrix/detail/select_k_half_int64_t.cu
src/matrix/detail/select_k_half_uint32_t.cu
src/neighbors/refine_d_int64_t_float.cu
src/neighbors/refine_d_int64_t_int8_t.cu
src/neighbors/refine_d_int64_t_uint8_t.cu
Expand Down
65 changes: 65 additions & 0 deletions cpp/include/raft/matrix/detail/select_k-ext.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cstdint> // uint32_t
#include <cuda_fp16.h> // __half
#include <raft/util/raft_explicit.hpp> // RAFT_EXPLICIT
#include <rmm/cuda_stream_view.hpp> // rmm:cuda_stream_view
#include <rmm/mr/device/device_memory_resource.hpp> // rmm::mr::device_memory_resource

#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY

namespace raft::matrix::detail {

template <typename T, typename IdxT>
void select_k(const T* in_val,
const IdxT* in_idx,
size_t batch_size,
size_t len,
int k,
T* out_val,
IdxT* out_idx,
bool select_min,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = nullptr) RAFT_EXPLICIT;
} // namespace raft::matrix::detail

#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY

#define instantiate_raft_matrix_detail_select_k(T, IdxT) \
extern template void raft::matrix::detail::select_k(const T* in_val, \
const IdxT* in_idx, \
size_t batch_size, \
size_t len, \
int k, \
T* out_val, \
IdxT* out_idx, \
bool select_min, \
rmm::cuda_stream_view stream, \
rmm::mr::device_memory_resource* mr)

instantiate_raft_matrix_detail_select_k(__half, uint32_t);
instantiate_raft_matrix_detail_select_k(__half, int64_t);
instantiate_raft_matrix_detail_select_k(float, int64_t);
instantiate_raft_matrix_detail_select_k(float, uint32_t);
// We did not have these two for double before, but there are tests for them. We
// therefore include them here.
instantiate_raft_matrix_detail_select_k(double, int64_t);
instantiate_raft_matrix_detail_select_k(double, uint32_t);

#undef instantiate_raft_matrix_detail_select_k
25 changes: 25 additions & 0 deletions cpp/include/raft/matrix/detail/select_k.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#if !defined(RAFT_EXPLICIT_INSTANTIATE_ONLY)
#include "select_k-inl.cuh"
#endif

#ifdef RAFT_COMPILED
#include "select_k-ext.cuh"
#endif
34 changes: 5 additions & 29 deletions cpp/include/raft/matrix/specializations/detail/select_k.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,32 +16,8 @@

#pragma once

#include <raft/matrix/detail/select_k.cuh>

#include <cuda_fp16.h>

namespace raft::matrix::detail {

#define RAFT_INST(T, IdxT) \
extern template void select_k<T, IdxT>(const T*, \
const IdxT*, \
size_t, \
size_t, \
int, \
T*, \
IdxT*, \
bool, \
rmm::cuda_stream_view, \
rmm::mr::device_memory_resource*);

// Commonly used types
RAFT_INST(float, int64_t);
RAFT_INST(half, int64_t);

// These instances are used in the ivf_pq::search parameterized by the internal_distance_dtype
RAFT_INST(float, uint32_t);
RAFT_INST(half, uint32_t);

#undef RAFT_INST

} // namespace raft::matrix::detail
#pragma message( \
__FILE__ \
" is deprecated and will be removed." \
" Including specializations is not necessary any more." \
" For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html")
33 changes: 33 additions & 0 deletions cpp/src/matrix/detail/select_k_double_int64_t.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <raft/matrix/detail/select_k-inl.cuh>

#define instantiate_raft_matrix_detail_select_k(T, IdxT) \
template void raft::matrix::detail::select_k(const T* in_val, \
const IdxT* in_idx, \
size_t batch_size, \
size_t len, \
int k, \
T* out_val, \
IdxT* out_idx, \
bool select_min, \
rmm::cuda_stream_view stream, \
rmm::mr::device_memory_resource* mr)

instantiate_raft_matrix_detail_select_k(double, int64_t);

#undef instantiate_raft_matrix_detail_select_k
34 changes: 34 additions & 0 deletions cpp/src/matrix/detail/select_k_double_uint32_t.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cstdint> // uint32_t
#include <raft/matrix/detail/select_k-inl.cuh>

#define instantiate_raft_matrix_detail_select_k(T, IdxT) \
template void raft::matrix::detail::select_k(const T* in_val, \
const IdxT* in_idx, \
size_t batch_size, \
size_t len, \
int k, \
T* out_val, \
IdxT* out_idx, \
bool select_min, \
rmm::cuda_stream_view stream, \
rmm::mr::device_memory_resource* mr)

instantiate_raft_matrix_detail_select_k(double, uint32_t);

#undef instantiate_raft_matrix_detail_select_k
33 changes: 33 additions & 0 deletions cpp/src/matrix/detail/select_k_float_int64_t.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <raft/matrix/detail/select_k-inl.cuh>

#define instantiate_raft_matrix_detail_select_k(T, IdxT) \
template void raft::matrix::detail::select_k(const T* in_val, \
const IdxT* in_idx, \
size_t batch_size, \
size_t len, \
int k, \
T* out_val, \
IdxT* out_idx, \
bool select_min, \
rmm::cuda_stream_view stream, \
rmm::mr::device_memory_resource* mr)

instantiate_raft_matrix_detail_select_k(float, int64_t);

#undef instantiate_raft_matrix_detail_select_k
33 changes: 33 additions & 0 deletions cpp/src/matrix/detail/select_k_float_uint32_t.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <raft/matrix/detail/select_k-inl.cuh>

#define instantiate_raft_matrix_detail_select_k(T, IdxT) \
template void raft::matrix::detail::select_k(const T* in_val, \
const IdxT* in_idx, \
size_t batch_size, \
size_t len, \
int k, \
T* out_val, \
IdxT* out_idx, \
bool select_min, \
rmm::cuda_stream_view stream, \
rmm::mr::device_memory_resource* mr)

instantiate_raft_matrix_detail_select_k(float, uint32_t);

#undef instantiate_raft_matrix_detail_select_k
33 changes: 33 additions & 0 deletions cpp/src/matrix/detail/select_k_half_int64_t.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <raft/matrix/detail/select_k-inl.cuh>

#define instantiate_raft_matrix_detail_select_k(T, IdxT) \
template void raft::matrix::detail::select_k(const T* in_val, \
const IdxT* in_idx, \
size_t batch_size, \
size_t len, \
int k, \
T* out_val, \
IdxT* out_idx, \
bool select_min, \
rmm::cuda_stream_view stream, \
rmm::mr::device_memory_resource* mr)

instantiate_raft_matrix_detail_select_k(__half, int64_t);

#undef instantiate_raft_matrix_detail_select_k
33 changes: 33 additions & 0 deletions cpp/src/matrix/detail/select_k_half_uint32_t.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <raft/matrix/detail/select_k-inl.cuh>

#define instantiate_raft_matrix_detail_select_k(T, IdxT) \
template void raft::matrix::detail::select_k(const T* in_val, \
const IdxT* in_idx, \
size_t batch_size, \
size_t len, \
int k, \
T* out_val, \
IdxT* out_idx, \
bool select_min, \
rmm::cuda_stream_view stream, \
rmm::mr::device_memory_resource* mr)

instantiate_raft_matrix_detail_select_k(__half, uint32_t);

#undef instantiate_raft_matrix_detail_select_k
36 changes: 0 additions & 36 deletions cpp/src/matrix/specializations/detail/select_k_float_int64_t.cu

This file was deleted.

Loading

0 comments on commit bc7bea4

Please sign in to comment.