Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[gpuCI] Forward-merge branch-21.10 to branch-21.12 [skip gpuci] #333

Merged
merged 5 commits into from
Sep 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ci/release/update-version.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,5 @@ function sed_runner() {
sed -i.bak ''"$1"'' $2 && rm -f ${2}.bak
}

sed_runner 's/'"RAFT VERSION .* LANGUAGES"'/'"RAFT VERSION ${NEXT_FULL_TAG} LANGUAGES"'/g' cpp/CMakeLists.txt
sed_runner 's/'"RAFT VERSION .* LANGUAGES"'/'"RAFT VERSION ${NEXT_FULL_TAG} LANGUAGES"'/g' cpp/CMakeLists.txt
sed_runner 's/'"branch-.*\/RAPIDS.cmake"'/'"branch-${NEXT_SHORT_TAG}\/RAPIDS.cmake"'/g' cpp/CMakeLists.txt
3 changes: 3 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ endif()
##############################################################################
# - install targets-----------------------------------------------------------
rapids_cmake_install_lib_dir( lib_dir )

include(CPack)

add_library(raft INTERFACE)
add_library(raft::raft ALIAS raft)
target_include_directories(raft INTERFACE "$<BUILD_INTERFACE:${RAFT_SOURCE_DIR}/include>"
Expand Down
28 changes: 21 additions & 7 deletions cpp/include/raft/handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,10 @@ class handle_t {
int cur_dev = -1;
CUDA_CHECK(cudaGetDevice(&cur_dev));
return cur_dev;
}()),
streams_(n_streams) {
}()) {
if (n_streams != 0) {
streams_ = std::make_unique<rmm::cuda_stream_pool>(n_streams);
}
create_resources();
thrust_policy_ = std::make_unique<rmm::exec_policy>(user_stream_);
}
Expand All @@ -78,10 +80,13 @@ class handle_t {
*/
handle_t(const handle_t& other, int stream_id,
int n_streams = kNumDefaultWorkerStreams)
: dev_id_(other.get_device()), streams_(n_streams) {
: dev_id_(other.get_device()) {
RAFT_EXPECTS(
other.get_num_internal_streams() > 0,
"ERROR: the main handle must have at least one worker stream\n");
if (n_streams != 0) {
streams_ = std::make_unique<rmm::cuda_stream_pool>(n_streams);
}
prop_ = other.get_device_properties();
device_prop_initialized_ = true;
create_resources();
Expand Down Expand Up @@ -140,14 +145,23 @@ class handle_t {

// legacy compatibility for cuML
cudaStream_t get_internal_stream(int sid) const {
return streams_.get_stream(sid).value();
RAFT_EXPECTS(
streams_.get() != nullptr,
"ERROR: rmm::cuda_stream_pool was not initialized with a non-zero value");
return streams_->get_stream(sid).value();
}
// new accessor return rmm::cuda_stream_view
rmm::cuda_stream_view get_internal_stream_view(int sid) const {
return streams_.get_stream(sid);
RAFT_EXPECTS(
streams_.get() != nullptr,
"ERROR: rmm::cuda_stream_pool was not initialized with a non-zero value");
return streams_->get_stream(sid);
}

int get_num_internal_streams() const {
return streams_.get() != nullptr ? streams_->get_pool_size() : 0;
}

int get_num_internal_streams() const { return streams_.get_pool_size(); }
std::vector<cudaStream_t> get_internal_streams() const {
std::vector<cudaStream_t> int_streams_vec;
for (int i = 0; i < get_num_internal_streams(); i++) {
Expand Down Expand Up @@ -212,7 +226,7 @@ class handle_t {
std::unordered_map<std::string, std::shared_ptr<comms::comms_t>> subcomms_;

const int dev_id_;
rmm::cuda_stream_pool streams_{0};
std::unique_ptr<rmm::cuda_stream_pool> streams_{nullptr};
mutable cublasHandle_t cublas_handle_;
mutable bool cublas_initialized_{false};
mutable cusolverDnHandle_t cusolver_dn_handle_;
Expand Down
21 changes: 21 additions & 0 deletions cpp/include/raft/linalg/contractions.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,27 @@ struct Policy4x4<double, _veclen> {
};
/** @} */

/**
* @defgroup Policy2x8 16 elements per thread Policy with k-block = 16
* @{
*/
template <typename DataT, int _veclen = 1>
struct Policy2x8 {};

template <int _veclen>
struct Policy2x8<float, _veclen> {
typedef KernelPolicy<float, _veclen, 16, 2, 8, 8, 32> Policy;
typedef ColKernelPolicy<float, _veclen, 16, 2, 8, 8, 32> ColPolicy;
};

template <int _veclen>
struct Policy2x8<double, _veclen> {
// this is not used just for keeping compiler happy.
typedef KernelPolicy<double, _veclen, 32, 1, 2, 8, 32> Policy;
typedef ColKernelPolicy<double, _veclen, 32, 1, 2, 8, 32> ColPolicy;
};
/** @} */

/**
* @brief Base class for gemm-like NT contractions
*
Expand Down
Loading