-
Notifications
You must be signed in to change notification settings - Fork 195
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
0e4bab3
commit 9dd110d
Showing
7 changed files
with
333 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# Fetch and configure the reference mdspan implementation through rapids-cpm.
#
# VERSION: the mdspan package version requested from CPM.
#
# Exposes the `std::mdspan` target globally and registers it with the
# raft-exports build/install export sets. CUDA support is enabled via the
# MDSPAN_ENABLE_CUDA option.
#
# NOTE: repository URL restored to github.com (the scraped page had the
# link-mangled form `github.com`).
function(find_and_configure_mdspan VERSION)
  rapids_cpm_find(
    mdspan ${VERSION}
    GLOBAL_TARGETS std::mdspan
    BUILD_EXPORT_SET raft-exports
    INSTALL_EXPORT_SET raft-exports
    CPM_ARGS
    GIT_REPOSITORY https://github.com/trivialfis/mdspan
    GIT_TAG 0193f075e977cc5f3c957425fd899e53d598f524
    OPTIONS "MDSPAN_ENABLE_CUDA ON"
            "MDSPAN_CXX_STANDARD ON"
  )
endfunction()

find_and_configure_mdspan(0.2.0)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,225 @@ | ||
/* | ||
* Copyright (c) 2022, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
#include <experimental/mdspan> | ||
#include <raft/cudart_utils.h> | ||
#include <rmm/cuda_stream_view.hpp> | ||
#include <rmm/device_uvector.hpp> | ||
#include <thrust/device_ptr.h> | ||
|
||
namespace raft::linalg { | ||
namespace detail { | ||
|
||
/** | ||
* @brief A simplified version of thrust::device_reference with support for CUDA stream. | ||
*/ | ||
template <typename T> | ||
class device_ref { | ||
public: | ||
using value_type = typename std::remove_cv_t<T>; | ||
using pointer = thrust::device_ptr<T>; | ||
using const_pointer = thrust::device_ptr<T const>; | ||
|
||
private: | ||
std::conditional_t<std::is_const<T>::value, const_pointer, pointer> ptr_; | ||
rmm::cuda_stream_view stream_; | ||
|
||
public: | ||
device_ref(thrust::device_ptr<T> ptr, rmm::cuda_stream_view stream) : ptr_{ptr}, stream_{stream} | ||
{ | ||
} | ||
|
||
operator value_type() const // NOLINT | ||
{ | ||
auto* raw = ptr_.get(); | ||
value_type v{}; | ||
RAFT_CUDA_TRY(cudaMemcpyAsync(&v, raw, sizeof(v), cudaMemcpyDeviceToHost, stream_.value())); | ||
return v; | ||
} | ||
auto operator=(T const& other) -> device_ref& | ||
{ | ||
auto* raw = ptr_.get(); | ||
RAFT_CUDA_TRY(cudaMemcpyAsync(raw, &other, sizeof(T), cudaMemcpyHostToDevice, stream_.value())); | ||
return *this; | ||
} | ||
}; | ||
|
||
/**
 * \brief A thin wrapper over rmm::device_uvector for implementing the mdarray container policy.
 *
 * Element access returns a stream-aware device_ref proxy rather than a raw
 * reference, since the storage lives in device memory.
 */
template <typename T>
class device_vector {
  rmm::device_uvector<T> data_;  // owning device buffer (stream-ordered allocation)

 public:
  using value_type = T;
  using size_type = std::size_t;

  // Proxy references that copy through the container's stream.
  using reference = device_ref<T>;
  using const_reference = device_ref<T const>;

  using pointer = value_type*;
  using const_pointer = value_type const*;

  using iterator = pointer;
  using const_iterator = const_pointer;

 public:
  ~device_vector() = default;
  device_vector(device_vector&&) noexcept = default;
  // Copy construction reuses the source vector's stream for the device copy.
  device_vector(device_vector const& that) : data_{that.data_, that.data_.stream()} {}

  /**
   * @brief Default ctor is deleted as it doesn't accept stream.
   */
  device_vector() = delete;
  /**
   * @brief Ctor that accepts a size, stream and an optional mr.
   */
  explicit device_vector(
    std::size_t size,
    rmm::cuda_stream_view stream,
    rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
    : data_{size, stream, mr}
  {
  }
  /**
   * @brief Index operator that returns a proxy to the actual data.
   */
  template <typename Index>
  auto operator[](Index i) -> reference
  {
    return device_ref<T>{thrust::device_ptr<T>{data_.data() + i}, data_.stream()};
  }
  /**
   * @brief Index operator that returns a proxy to the actual data.
   */
  template <typename Index>
  auto operator[](Index i) const
  {
    return device_ref<T const>{thrust::device_ptr<T const>{data_.data() + i}, data_.stream()};
  }

  // Raw device pointers to the underlying buffer.
  [[nodiscard]] auto data() -> pointer { return data_.data(); }
  [[nodiscard]] auto data() const -> const_pointer { return data_.data(); }
};
|
||
/**
 * @brief mdarray container policy backed by detail::device_vector<float>.
 *
 * NOTE(review): stream_ is never explicitly set; a default-constructed
 * rmm::cuda_stream_view presumably refers to the default CUDA stream —
 * confirm this is intended for allocations made via create().
 */
class uvector_policy {
  rmm::cuda_stream_view stream_;

 public:
  using container_type = device_vector<float>;
  using pointer = typename container_type::pointer;
  using const_pointer = typename container_type::const_pointer;
  using reference = device_ref<float>;
  using const_reference = device_ref<float const>;
  using accessor_policy = std::experimental::default_accessor<float>;

 public:
  // Allocate a device container with n elements on the policy's stream.
  auto create(size_t n) -> container_type { return container_type(n, stream_); }
  /**
   * @brief A quick hack to pass the policy to `stdex::mdspan` without passing down the stream.
   */
  operator std::experimental::default_accessor<float>() const { return {}; }  // NOLINT

  uvector_policy() = default;
  // Return a stream-aware proxy reference to element n of container c.
  [[nodiscard]] auto access(container_type& c, size_t n) const { return c[n]; }
};
} // namespace detail | ||
|
||
namespace stdex = std::experimental; | ||
|
||
/** | ||
* @brief Modified from the c++ mdarray proposal, with the differences listed below. | ||
* | ||
* - Layout policy is different, the mdarray in raft uses `stdex::extent` directly just | ||
* like `mdspan`, while the `mdarray` in the reference implementation uses varidic | ||
* template. | ||
* | ||
* - Most of the constructors from the reference implementation is removed to make sure | ||
* CUDA stream is honorred. | ||
*/ | ||
template <class ElementType, | ||
class Extents, | ||
class LayoutPolicy = stdex::layout_right, | ||
class AccessorPolicy = stdex::default_accessor<ElementType>> | ||
class mdarray { | ||
public: | ||
using element_type = ElementType; | ||
using extents_type = Extents; | ||
using layout_type = LayoutPolicy; | ||
using mapping_type = typename layout_type::template mapping<extents_type>; | ||
|
||
using index_type = std::size_t; | ||
using difference_type = std::ptrdiff_t; | ||
using container_policy_type = AccessorPolicy; | ||
using container_type = typename container_policy_type::container_type; | ||
|
||
static_assert(!std::is_const<ElementType>::value, | ||
"Element type for container must not be const."); | ||
|
||
using pointer = typename container_policy_type::pointer; | ||
using const_pointer = typename container_policy_type::const_pointer; | ||
using reference = typename container_policy_type::reference; | ||
using const_reference = typename container_policy_type::const_reference; | ||
using view_type = stdex::mdspan<element_type, | ||
extents_type, | ||
layout_type, | ||
typename container_policy_type::accessor_policy>; | ||
using const_view_type = stdex::mdspan<element_type const, | ||
extents_type, | ||
layout_type, | ||
typename container_policy_type::accessor_policy>; | ||
|
||
public: | ||
constexpr mdarray() noexcept(std::is_nothrow_default_constructible<container_type>::value) = | ||
default; | ||
constexpr mdarray(mdarray const&) noexcept( | ||
std::is_nothrow_copy_constructible<container_type>::value) = default; | ||
constexpr mdarray(mdarray&&) noexcept(std::is_nothrow_move_constructible<container_type>::value) = | ||
default; | ||
|
||
auto operator =(mdarray&&) noexcept(std::is_nothrow_move_assignable<container_type>::value) | ||
-> mdarray& = default; | ||
|
||
auto operator =(mdarray const&) noexcept(std::is_nothrow_copy_assignable<container_type>::value) | ||
-> mdarray& = default; | ||
|
||
~mdarray() noexcept(std::is_nothrow_destructible<container_type>::value) = default; | ||
|
||
constexpr mdarray(mapping_type const& m, container_policy_type const& cp) | ||
: cp_(cp), map_(m), c_(cp_.create(map_.required_span_size())) | ||
{ | ||
} | ||
|
||
auto view() noexcept { return view_type(c_.data(), map_, cp_); } | ||
|
||
auto data() noexcept -> pointer { return cp_.data(); } | ||
constexpr auto data() const noexcept -> const_pointer { return cp_.data(); } | ||
|
||
template <typename... IndexType> | ||
auto operator()(IndexType&&... indices) | ||
-> std::enable_if_t<sizeof...(IndexType) == extents_type::rank(), reference> | ||
{ | ||
return cp_.access(c_, map_(index_type(indices)...)); | ||
} | ||
|
||
private: | ||
container_policy_type cp_; | ||
mapping_type map_; | ||
container_type c_; | ||
}; | ||
} // namespace raft::linalg |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
/* | ||
* Copyright (c) 2022, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
#include "mdarray.h" | ||
#include <experimental/mdspan> | ||
#include <gtest/gtest.h> | ||
#include <raft/cudart_utils.h> | ||
#include <rmm/device_uvector.hpp> | ||
#include <rmm/exec_policy.hpp> | ||
#include <thrust/for_each.h> | ||
#include <thrust/device_vector.h> | ||
#include <thrust/iterator/counting_iterator.h> | ||
|
||
namespace {
namespace stdex = std::experimental;

// Smoke test: wrap device memory in an mdspan and verify element access and
// submdspan slicing from inside a device lambda.
void test_mdspan()
{
  auto it = thrust::make_counting_iterator(0ul);

  cudaStream_t stream = nullptr;  // null handle, i.e. the default stream
  rmm::device_uvector<float> a{16ul, stream};
  // Fill the buffer with the sequence 0..15 on the device.
  thrust::sequence(rmm::exec_policy(stream), a.begin(), a.end());

  // View the 16-element buffer as a dynamically-sized 4x4 matrix.
  stdex::mdspan<float, stdex::extents<stdex::dynamic_extent, stdex::dynamic_extent>> span{
    a.data(), 4, 4};

  thrust::for_each(it, it + 4, [=] __device__(size_t i) {
    // Row 0 of the row-major 4x4 sequence holds 0,1,2,3.
    auto v = span(0, i);
    DEVICE_ASSERT(v == i);
    // Slicing row 0 out with submdspan must preserve the same values.
    auto k = stdex::submdspan(span, 0, stdex::full_extent);
    DEVICE_ASSERT(k(i) == i);
  });
}

}  // namespace
|
||
namespace raft::linalg {
// Exercises mdarray construction, host-side proxy read/write, device-side
// access through view(), and the standalone device_vector proxy.
void test_mdarray()
{
  detail::uvector_policy policy;
  using matrix_extent = stdex::extents<stdex::dynamic_extent, stdex::dynamic_extent>;
  stdex::layout_right::mapping<matrix_extent> layout{matrix_extent{4, 4}};
  mdarray<float, matrix_extent, stdex::layout_right, detail::uvector_policy> array{layout, policy};

  // NOTE(review): element (0,0) has not been written yet, so this prints an
  // indeterminate value — presumably intentional as a smoke read; confirm.
  std::cout << array(0, 0) << std::endl;
  // Write through the device_ref proxy, then read the value back on the host.
  array(0, 3) = 1;
  std::cout << array(0, 3) << std::endl;
  ASSERT_EQ(array(0, 3), 1);
  auto view = array.view();
  auto it = thrust::make_counting_iterator(0ul);
  // The written element must also be visible from device code via the view.
  thrust::for_each(
    rmm::exec_policy(rmm::cuda_stream_default), it, it + 1, [view] __device__(auto i) {
      DEVICE_ASSERT(view(0, 3) == 1);
    });

  // Reference behaviour: thrust::device_vector's own proxy reference.
  thrust::device_vector<float> vec(10);
  vec[0] = 1.0;
  auto b = vec[0];
  std::cout << b << std::endl;

  // detail::device_vector's proxy: converting to float copies device-to-host.
  detail::device_vector<float> dvec(10, rmm::cuda_stream_default);
  auto a = dvec[2];
  float c = a;
  std::cout << c << std::endl;
}
}  // namespace raft::linalg
|
||
// Register the two free-function tests with GoogleTest.
TEST(MDSpan, Kernel) { test_mdspan(); }
TEST(MDArray, Basic) { raft::linalg::test_mdarray(); }