Skip to content

Commit

Permalink
Add python bindings for matrix::select_k (#1422)
Browse files Browse the repository at this point in the history
Authors:
  - Ben Frederickson (https://github.com/benfred)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1422
  • Loading branch information
benfred authored Apr 17, 2023
1 parent ba207a0 commit 574f8f8
Show file tree
Hide file tree
Showing 17 changed files with 389 additions and 22 deletions.
7 changes: 3 additions & 4 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,11 @@ option(RAFT_COMPILE_LIBRARY "Enable building raft shared library instantiations"
${RAFT_COMPILE_LIBRARY_DEFAULT}
)


# Needed because GoogleBenchmark changes the state of FindThreads.cmake, causing subsequent runs
# to have different values for the `Threads::Threads` target. Setting this flag ensures
# Needed because GoogleBenchmark changes the state of FindThreads.cmake, causing subsequent runs to
# have different values for the `Threads::Threads` target. Setting this flag ensures
# `Threads::Threads` is the same value across all builds so that cache hits occur
set(THREADS_PREFER_PTHREAD_FLAG ON)


include(CMakeDependentOption)
# cmake_dependent_option( RAFT_USE_FAISS_STATIC "Build and statically link the FAISS library for
# nearest neighbors search on GPU" ON RAFT_COMPILE_LIBRARY OFF )
Expand Down Expand Up @@ -329,6 +327,7 @@ if(RAFT_COMPILE_LIBRARY)
src/distance/specializations/fused_l2_nn_double_int64.cu
src/distance/specializations/fused_l2_nn_float_int.cu
src/distance/specializations/fused_l2_nn_float_int64.cu
src/matrix/select_k_float_int64_t.cu
src/matrix/specializations/detail/select_k_float_uint32_t.cu
src/matrix/specializations/detail/select_k_float_int64_t.cu
src/matrix/specializations/detail/select_k_half_uint32_t.cu
Expand Down
20 changes: 10 additions & 10 deletions cpp/include/raft/matrix/select_k.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,13 @@ namespace raft::matrix {
* @code{.cpp}
* using namespace raft;
* // get a 2D row-major array of values to search through
* auto in_values = {... input device_matrix_view<const float, size_t, row_major> ...}
* auto in_values = {... input device_matrix_view<const float, int64_t, row_major> ...}
* // prepare output arrays
* auto out_extents = make_extents<size_t>(in_values.extent(0), k);
* auto out_extents = make_extents<int64_t>(in_values.extent(0), k);
* auto out_values = make_device_mdarray<float>(handle, out_extents);
* auto out_indices = make_device_mdarray<size_t>(handle, out_extents);
* auto out_indices = make_device_mdarray<int64_t>(handle, out_extents);
* // search `k` smallest values in each row
* matrix::select_k<float, size_t>(
* matrix::select_k<float, int64_t>(
* handle, in_values, std::nullopt, out_values.view(), out_indices.view(), true);
* @endcode
*
Expand Down Expand Up @@ -76,13 +76,13 @@ namespace raft::matrix {
*/
template <typename T, typename IdxT>
void select_k(const device_resources& handle,
raft::device_matrix_view<const T, size_t, row_major> in_val,
std::optional<raft::device_matrix_view<const IdxT, size_t, row_major>> in_idx,
raft::device_matrix_view<T, size_t, row_major> out_val,
raft::device_matrix_view<IdxT, size_t, row_major> out_idx,
raft::device_matrix_view<const T, int64_t, row_major> in_val,
std::optional<raft::device_matrix_view<const IdxT, int64_t, row_major>> in_idx,
raft::device_matrix_view<T, int64_t, row_major> out_val,
raft::device_matrix_view<IdxT, int64_t, row_major> out_idx,
bool select_min)
{
RAFT_EXPECTS(out_val.extent(1) <= size_t(std::numeric_limits<int>::max()),
RAFT_EXPECTS(out_val.extent(1) <= int64_t(std::numeric_limits<int>::max()),
"output k must fit the int type.");
auto batch_size = in_val.extent(0);
auto len = in_val.extent(1);
Expand All @@ -93,7 +93,7 @@ void select_k(const device_resources& handle,
RAFT_EXPECTS(batch_size == in_idx->extent(0), "batch sizes must be equal");
RAFT_EXPECTS(len == in_idx->extent(1), "value and index input lengths must be equal");
}
RAFT_EXPECTS(size_t(k) == out_idx.extent(1), "value and index output lengths must be equal");
RAFT_EXPECTS(int64_t(k) == out_idx.extent(1), "value and index output lengths must be equal");
return detail::select_k<T, IdxT>(in_val.data_handle(),
in_idx.has_value() ? in_idx->data_handle() : nullptr,
batch_size,
Expand Down
32 changes: 32 additions & 0 deletions cpp/include/raft_runtime/matrix/select_k.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/core/device_mdspan.hpp>
#include <raft/core/device_resources.hpp>

#include <optional>

namespace raft::runtime::matrix {
void select_k(const device_resources& handle,
raft::device_matrix_view<const float, int64_t, row_major> in_val,
std::optional<raft::device_matrix_view<const int64_t, int64_t, row_major>> in_idx,
raft::device_matrix_view<float, int64_t, row_major> out_val,
raft::device_matrix_view<int64_t, int64_t, row_major> out_idx,
bool select_min);

} // namespace raft::runtime::matrix
13 changes: 7 additions & 6 deletions cpp/internal/raft_internal/matrix/select_k.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,13 @@ void select_k_impl(const device_resources& handle,
auto stream = handle.get_stream();
switch (algo) {
case Algo::kPublicApi: {
auto in_extent = make_extents<size_t>(batch_size, len);
auto out_extent = make_extents<size_t>(batch_size, k);
auto in_span = make_mdspan<const T, size_t, row_major, false, true>(in, in_extent);
auto in_idx_span = make_mdspan<const IdxT, size_t, row_major, false, true>(in_idx, in_extent);
auto out_span = make_mdspan<T, size_t, row_major, false, true>(out, out_extent);
auto out_idx_span = make_mdspan<IdxT, size_t, row_major, false, true>(out_idx, out_extent);
auto in_extent = make_extents<int64_t>(batch_size, len);
auto out_extent = make_extents<int64_t>(batch_size, k);
auto in_span = make_mdspan<const T, int64_t, row_major, false, true>(in, in_extent);
auto in_idx_span =
make_mdspan<const IdxT, int64_t, row_major, false, true>(in_idx, in_extent);
auto out_span = make_mdspan<T, int64_t, row_major, false, true>(out, out_extent);
auto out_idx_span = make_mdspan<IdxT, int64_t, row_major, false, true>(out_idx, out_extent);
if (in_idx == nullptr) {
// NB: std::nullopt prevents automatic inference of the template parameters.
return matrix::select_k<T, IdxT>(
Expand Down
37 changes: 37 additions & 0 deletions cpp/src/matrix/select_k_float_int64_t.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <raft/core/device_mdspan.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/matrix/select_k.cuh>
#include <raft/matrix/specializations.cuh>

#include <raft_runtime/matrix/select_k.hpp>

#include <vector>

namespace raft::runtime::matrix {

void select_k(const device_resources& handle,
raft::device_matrix_view<const float, int64_t, row_major> in_val,
std::optional<raft::device_matrix_view<const int64_t, int64_t, row_major>> in_idx,
raft::device_matrix_view<float, int64_t, row_major> out_val,
raft::device_matrix_view<int64_t, int64_t, row_major> out_idx,
bool select_min)
{
raft::matrix::select_k(handle, in_val, in_idx, out_val, out_idx, select_min);
}
} // namespace raft::runtime::matrix
1 change: 1 addition & 0 deletions python/pylibraft/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ rapids_cython_init()

add_subdirectory(pylibraft/common)
add_subdirectory(pylibraft/distance)
add_subdirectory(pylibraft/matrix)
add_subdirectory(pylibraft/neighbors)
add_subdirectory(pylibraft/random)
add_subdirectory(pylibraft/cluster)
Expand Down
24 changes: 24 additions & 0 deletions python/pylibraft/pylibraft/matrix/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# =============================================================================
# Copyright (c) 2022-2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.
# =============================================================================

# Set the list of Cython files to build
set(cython_sources select_k.pyx)
set(linked_libraries raft::raft raft::compiled)

# Build all of the Cython targets
rapids_cython_create_modules(
CXX
SOURCE_FILES "${cython_sources}"
LINKED_LIBRARIES "${linked_libraries}" ASSOCIATED_TARGETS raft MODULE_PREFIX matrix_
)
14 changes: 14 additions & 0 deletions python/pylibraft/pylibraft/matrix/__init__.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright (c) 2022-2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
18 changes: 18 additions & 0 deletions python/pylibraft/pylibraft/matrix/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright (c) 2022-2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from .select_k import select_k

__all__ = ["select_k"]
Empty file.
14 changes: 14 additions & 0 deletions python/pylibraft/pylibraft/matrix/cpp/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
39 changes: 39 additions & 0 deletions python/pylibraft/pylibraft/matrix/cpp/select_k.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#
# Copyright (c) 2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# cython: profile=False
# distutils: language = c++
# cython: embedsignature = True
# cython: language_level = 3

from libc.stdint cimport int64_t
from libcpp cimport bool

from pylibraft.common.cpp.mdspan cimport device_matrix_view, row_major
from pylibraft.common.cpp.optional cimport optional
from pylibraft.common.handle cimport device_resources


cdef extern from "raft_runtime/matrix/select_k.hpp" \
namespace "raft::runtime::matrix" nogil:

cdef void select_k(const device_resources & handle,
device_matrix_view[float, int64_t, row_major],
optional[device_matrix_view[int64_t,
int64_t,
row_major]],
device_matrix_view[float, int64_t, row_major],
device_matrix_view[int64_t, int64_t, row_major],
bool) except +
Loading

0 comments on commit 574f8f8

Please sign in to comment.