Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Using raft::resources across raft::random #1420

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions cpp/include/raft/linalg/detail/qr.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

#include "cublas_wrappers.hpp"
#include "cusolver_wrappers.hpp"
#include <raft/core/resource/cusolver_dn_handle.hpp>
#include <raft/core/resources.hpp>
#include <raft/matrix/matrix.cuh>
#include <rmm/device_scalar.hpp>
#include <rmm/device_uvector.hpp>
Expand All @@ -42,10 +44,10 @@ namespace detail {
*/
template <typename math_t>
void qrGetQ_inplace(
raft::device_resources const& handle, math_t* Q, int n_rows, int n_cols, cudaStream_t stream)
raft::resources const& handle, math_t* Q, int n_rows, int n_cols, cudaStream_t stream)
{
RAFT_EXPECTS(n_rows >= n_cols, "QR decomposition expects n_rows >= n_cols.");
cusolverDnHandle_t cusolver = handle.get_cusolver_dn_handle();
cusolverDnHandle_t cusolver = resource::get_cusolver_dn_handle(handle);

rmm::device_uvector<math_t> tau(n_cols, stream);
RAFT_CUDA_TRY(cudaMemsetAsync(tau.data(), 0, sizeof(math_t) * n_cols, stream));
Expand Down Expand Up @@ -83,7 +85,7 @@ void qrGetQ_inplace(
}

template <typename math_t>
void qrGetQ(raft::device_resources const& handle,
void qrGetQ(raft::resources const& handle,
const math_t* M,
math_t* Q,
int n_rows,
Expand All @@ -95,15 +97,15 @@ void qrGetQ(raft::device_resources const& handle,
}

template <typename math_t>
void qrGetQR(raft::device_resources const& handle,
void qrGetQR(raft::resources const& handle,
math_t* M,
math_t* Q,
math_t* R,
int n_rows,
int n_cols,
cudaStream_t stream)
{
cusolverDnHandle_t cusolverH = handle.get_cusolver_dn_handle();
cusolverDnHandle_t cusolverH = resource::get_cusolver_dn_handle(handle);

int m = n_rows, n = n_cols;
rmm::device_uvector<math_t> R_full(m * n, stream);
Expand Down
20 changes: 11 additions & 9 deletions cpp/include/raft/linalg/detail/transpose.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
#include "cublas_wrappers.hpp"

#include <raft/core/device_mdspan.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/core/resource/cublas_handle.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <rmm/exec_policy.hpp>
#include <thrust/for_each.h>
#include <thrust/iterator/counting_iterator.h>
Expand All @@ -29,14 +31,14 @@ namespace linalg {
namespace detail {

template <typename math_t>
void transpose(raft::device_resources const& handle,
void transpose(raft::resources const& handle,
math_t* in,
math_t* out,
int n_rows,
int n_cols,
cudaStream_t stream)
{
cublasHandle_t cublas_h = handle.get_cublas_handle();
cublasHandle_t cublas_h = resource::get_cublas_handle(handle);
RAFT_CUBLAS_TRY(cublasSetStream(cublas_h, stream));

int out_n_rows = n_cols;
Expand Down Expand Up @@ -83,7 +85,7 @@ void transpose(math_t* inout, int n, cudaStream_t stream)

template <typename T, typename IndexType, typename LayoutPolicy, typename AccessorPolicy>
void transpose_row_major_impl(
raft::device_resources const& handle,
raft::resources const& handle,
raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> in,
raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> out)
{
Expand All @@ -92,7 +94,7 @@ void transpose_row_major_impl(
T constexpr kOne = 1;
T constexpr kZero = 0;

CUBLAS_TRY(cublasgeam(handle.get_cublas_handle(),
CUBLAS_TRY(cublasgeam(resource::get_cublas_handle(handle),
CUBLAS_OP_T,
CUBLAS_OP_N,
out_n_cols,
Expand All @@ -105,12 +107,12 @@ void transpose_row_major_impl(
out.stride(0),
out.data_handle(),
out.stride(0),
handle.get_stream()));
resource::get_cuda_stream(handle)));
}

template <typename T, typename IndexType, typename LayoutPolicy, typename AccessorPolicy>
void transpose_col_major_impl(
raft::device_resources const& handle,
raft::resources const& handle,
raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> in,
raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> out)
{
Expand All @@ -119,7 +121,7 @@ void transpose_col_major_impl(
T constexpr kOne = 1;
T constexpr kZero = 0;

CUBLAS_TRY(cublasgeam(handle.get_cublas_handle(),
CUBLAS_TRY(cublasgeam(resource::get_cublas_handle(handle),
CUBLAS_OP_T,
CUBLAS_OP_N,
out_n_rows,
Expand All @@ -132,7 +134,7 @@ void transpose_col_major_impl(
out.stride(1),
out.data_handle(),
out.stride(1),
handle.get_stream()));
resource::get_cuda_stream(handle)));
}
}; // end namespace detail
}; // end namespace linalg
Expand Down
19 changes: 13 additions & 6 deletions cpp/include/raft/linalg/qr.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#pragma once

#include "detail/qr.cuh"
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>

namespace raft {
namespace linalg {
Expand All @@ -33,7 +35,7 @@ namespace linalg {
* @param stream cuda stream
*/
template <typename math_t>
void qrGetQ(raft::device_resources const& handle,
void qrGetQ(raft::resources const& handle,
const math_t* M,
math_t* Q,
int n_rows,
Expand All @@ -54,7 +56,7 @@ void qrGetQ(raft::device_resources const& handle,
* @param stream cuda stream
*/
template <typename math_t>
void qrGetQR(raft::device_resources const& handle,
void qrGetQR(raft::resources const& handle,
math_t* M,
math_t* Q,
math_t* R,
Expand All @@ -77,13 +79,18 @@ void qrGetQR(raft::device_resources const& handle,
* @param[out] Q Output raft::device_matrix_view
*/
template <typename ElementType, typename IndexType>
void qr_get_q(raft::device_resources const& handle,
void qr_get_q(raft::resources const& handle,
raft::device_matrix_view<const ElementType, IndexType, raft::col_major> M,
raft::device_matrix_view<ElementType, IndexType, raft::col_major> Q)
{
RAFT_EXPECTS(Q.size() == M.size(), "Size mismatch between Output and Input");

qrGetQ(handle, M.data_handle(), Q.data_handle(), M.extent(0), M.extent(1), handle.get_stream());
qrGetQ(handle,
M.data_handle(),
Q.data_handle(),
M.extent(0),
M.extent(1),
resource::get_cuda_stream(handle));
}

/**
Expand All @@ -94,7 +101,7 @@ void qr_get_q(raft::device_resources const& handle,
* @param[out] R Output raft::device_matrix_view
*/
template <typename ElementType, typename IndexType>
void qr_get_qr(raft::device_resources const& handle,
void qr_get_qr(raft::resources const& handle,
raft::device_matrix_view<const ElementType, IndexType, raft::col_major> M,
raft::device_matrix_view<ElementType, IndexType, raft::col_major> Q,
raft::device_matrix_view<ElementType, IndexType, raft::col_major> R)
Expand All @@ -107,7 +114,7 @@ void qr_get_qr(raft::device_resources const& handle,
R.data_handle(),
M.extent(0),
M.extent(1),
handle.get_stream());
resource::get_cuda_stream(handle));
}

/** @} */
Expand Down
5 changes: 3 additions & 2 deletions cpp/include/raft/linalg/transpose.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include "detail/transpose.cuh"
#include <raft/core/device_mdarray.hpp>
#include <raft/core/resources.hpp>

namespace raft {
namespace linalg {
Expand All @@ -34,7 +35,7 @@ namespace linalg {
* @param stream: cuda stream
*/
template <typename math_t>
void transpose(raft::device_resources const& handle,
void transpose(raft::resources const& handle,
math_t* in,
math_t* out,
int n_rows,
Expand Down Expand Up @@ -76,7 +77,7 @@ void transpose(math_t* inout, int n, cudaStream_t stream)
* @param[out] out Output matrix, storage is pre-allocated by caller.
*/
template <typename T, typename IndexType, typename LayoutPolicy, typename AccessorPolicy>
auto transpose(raft::device_resources const& handle,
auto transpose(raft::resources const& handle,
raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> in,
raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> out)
-> std::enable_if_t<std::is_floating_point_v<T>, void>
Expand Down
13 changes: 6 additions & 7 deletions cpp/include/raft/random/detail/make_regression.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@

#include <algorithm>

#include <raft/core/device_resources.hpp>
#include <raft/core/resource/cublas_handle.hpp>
#include <raft/core/resources.hpp>
#include <raft/linalg/add.cuh>
#include <raft/linalg/detail/cublas_wrappers.hpp>
#include <raft/linalg/init.cuh>
Expand Down Expand Up @@ -52,7 +53,7 @@ static __global__ void _singular_profile_kernel(DataT* out, IdxT n, DataT tail_s

/* Internal auxiliary function to generate a low-rank matrix */
template <typename DataT, typename IdxT>
static void _make_low_rank_matrix(raft::device_resources const& handle,
static void _make_low_rank_matrix(raft::resources const& handle,
DataT* out,
IdxT n_rows,
IdxT n_cols,
Expand All @@ -61,8 +62,7 @@ static void _make_low_rank_matrix(raft::device_resources const& handle,
raft::random::RngState& r,
cudaStream_t stream)
{
cusolverDnHandle_t cusolver_handle = handle.get_cusolver_dn_handle();
cublasHandle_t cublas_handle = handle.get_cublas_handle();
cublasHandle_t cublas_handle = resource::get_cublas_handle(handle);

IdxT n = std::min(n_rows, n_cols);

Expand Down Expand Up @@ -143,7 +143,7 @@ static __global__ void _gather2d_kernel(
}

template <typename DataT, typename IdxT>
void make_regression_caller(raft::device_resources const& handle,
void make_regression_caller(raft::resources const& handle,
DataT* out,
DataT* values,
IdxT n_rows,
Expand All @@ -162,8 +162,7 @@ void make_regression_caller(raft::device_resources const& handle,
{
n_informative = std::min(n_informative, n_cols);

cusolverDnHandle_t cusolver_handle = handle.get_cusolver_dn_handle();
cublasHandle_t cublas_handle = handle.get_cublas_handle();
cublasHandle_t cublas_handle = resource::get_cublas_handle(handle);

cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_HOST);
raft::random::RngState r(seed, type);
Expand Down
36 changes: 19 additions & 17 deletions cpp/include/raft/random/detail/multi_variable_gaussian.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
#include <memory>
#include <optional>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/core/resource/cublas_handle.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/cusolver_dn_handle.hpp>
#include <raft/core/resources.hpp>
#include <raft/linalg/detail/cublas_wrappers.hpp>
#include <raft/linalg/detail/cusolver_wrappers.hpp>
#include <raft/linalg/matrix_vector_op.cuh>
Expand Down Expand Up @@ -139,18 +142,16 @@ class multi_variable_gaussian_impl {
int *info, Lwork, info_h;
syevjInfo_t syevj_params = NULL;
curandGenerator_t gen;
raft::device_resources const& handle;
raft::resources const& handle;
cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR;
bool deinitilized = false;

public: // functions
multi_variable_gaussian_impl() = delete;
multi_variable_gaussian_impl(raft::device_resources const& handle,
const int dim,
Decomposer method)
multi_variable_gaussian_impl(raft::resources const& handle, const int dim, Decomposer method)
: handle(handle), dim(dim), method(method)
{
auto cusolverHandle = handle.get_cusolver_dn_handle();
auto cusolverHandle = resource::get_cusolver_dn_handle(handle);

CURAND_CHECK(curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT));
CURAND_CHECK(curandSetPseudoRandomGeneratorSeed(gen, 28)); // SEED
Expand Down Expand Up @@ -191,9 +192,9 @@ class multi_variable_gaussian_impl {

void give_gaussian(const int nPoints, T* P, T* X, const T* x = 0)
{
auto cusolverHandle = handle.get_cusolver_dn_handle();
auto cublasHandle = handle.get_cublas_handle();
auto cudaStream = handle.get_stream();
auto cusolverHandle = resource::get_cusolver_dn_handle(handle);
auto cublasHandle = resource::get_cublas_handle(handle);
auto cudaStream = resource::get_cuda_stream(handle);
if (method == chol_decomp) {
// lower part will contains chol_decomp
RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDnpotrf(
Expand Down Expand Up @@ -299,7 +300,7 @@ class multi_variable_gaussian_setup_token;

template <typename ValueType>
multi_variable_gaussian_setup_token<ValueType> build_multi_variable_gaussian_token_impl(
raft::device_resources const& handle,
raft::resources const& handle,
rmm::mr::device_memory_resource& mem_resource,
const int dim,
const multi_variable_gaussian_decomposition_method method);
Expand All @@ -315,7 +316,7 @@ template <typename ValueType>
class multi_variable_gaussian_setup_token {
template <typename T>
friend multi_variable_gaussian_setup_token<T> build_multi_variable_gaussian_token_impl(
raft::device_resources const& handle,
raft::resources const& handle,
rmm::mr::device_memory_resource& mem_resource,
const int dim,
const multi_variable_gaussian_decomposition_method method);
Expand All @@ -342,7 +343,7 @@ class multi_variable_gaussian_setup_token {

// Constructor, only for use by friend functions.
// Hiding this will let us change the implementation in the future.
multi_variable_gaussian_setup_token(raft::device_resources const& handle,
multi_variable_gaussian_setup_token(raft::resources const& handle,
rmm::mr::device_memory_resource& mem_resource,
const int dim,
const multi_variable_gaussian_decomposition_method method)
Expand Down Expand Up @@ -399,22 +400,23 @@ class multi_variable_gaussian_setup_token {

private:
std::unique_ptr<multi_variable_gaussian_impl<ValueType>> impl_;
raft::device_resources const& handle_;
raft::resources const& handle_;
rmm::mr::device_memory_resource& mem_resource_;
int dim_ = 0;

auto allocate_workspace() const
{
const auto num_elements = impl_->get_workspace_size();
return rmm::device_uvector<ValueType>{num_elements, handle_.get_stream(), &mem_resource_};
return rmm::device_uvector<ValueType>{
num_elements, resource::get_cuda_stream(handle_), &mem_resource_};
}

int dim() const { return dim_; }
};

template <typename ValueType>
multi_variable_gaussian_setup_token<ValueType> build_multi_variable_gaussian_token_impl(
raft::device_resources const& handle,
raft::resources const& handle,
rmm::mr::device_memory_resource& mem_resource,
const int dim,
const multi_variable_gaussian_decomposition_method method)
Expand All @@ -434,7 +436,7 @@ void compute_multi_variable_gaussian_impl(

template <typename ValueType>
void compute_multi_variable_gaussian_impl(
raft::device_resources const& handle,
raft::resources const& handle,
rmm::mr::device_memory_resource& mem_resource,
std::optional<raft::device_vector_view<const ValueType, int>> x,
raft::device_matrix_view<ValueType, int, raft::col_major> P,
Expand All @@ -455,7 +457,7 @@ class multi_variable_gaussian : public detail::multi_variable_gaussian_impl<T> {
// using detail::multi_variable_gaussian_impl<T>::Decomposer::qr;

multi_variable_gaussian() = delete;
multi_variable_gaussian(raft::device_resources const& handle,
multi_variable_gaussian(raft::resources const& handle,
const int dim,
typename detail::multi_variable_gaussian_impl<T>::Decomposer method)
: detail::multi_variable_gaussian_impl<T>{handle, dim, method}
Expand Down
Loading