Skip to content

Commit

Permalink
Comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jan 5, 2022
1 parent 14fad1a commit a98ba0f
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 13 deletions.
21 changes: 19 additions & 2 deletions cpp/test/linalg/mdarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,16 @@ class uvector_policy {

namespace stdex = std::experimental;

/**
* @brief Modified from the c++ mdarray proposal, with the differences listed below.
*
* - Layout policy is different, the mdarray in raft uses `stdex::extent` directly just
* like `mdspan`, while the `mdarray` in the reference implementation uses varidic
* template.
*
* - Most of the constructors from the reference implementation is removed to make sure
* CUDA stream is honorred.
*/
template <class ElementType,
class Extents,
class LayoutPolicy = stdex::layout_right,
Expand All @@ -153,11 +163,14 @@ class mdarray {
using layout_type = LayoutPolicy;
using mapping_type = typename layout_type::template mapping<extents_type>;

using index_type = size_t;
using difference_type = ptrdiff_t;
using index_type = std::size_t;
using difference_type = std::ptrdiff_t;
using container_policy_type = AccessorPolicy;
using container_type = typename container_policy_type::container_type;

static_assert(!std::is_const<ElementType>::value,
"Element type for container must not be const.");

using pointer = typename container_policy_type::pointer;
using const_pointer = typename container_policy_type::const_pointer;
using reference = typename container_policy_type::reference;
Expand All @@ -166,6 +179,10 @@ class mdarray {
extents_type,
layout_type,
typename container_policy_type::accessor_policy>;
using const_view_type = stdex::mdspan<element_type const,
extents_type,
layout_type,
typename container_policy_type::accessor_policy>;

public:
constexpr mdarray() noexcept(std::is_nothrow_default_constructible<container_type>::value) =
Expand Down
17 changes: 6 additions & 11 deletions cpp/test/linalg/mdspan.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,15 @@
* limitations under the License.
*/
#include "mdarray.h"
#include <experimental/mdarray>
#include <experimental/mdspan>
#include <gtest/gtest.h>
#include <raft/cudart_utils.h>
#include <rmm/device_buffer.hpp>
#include <rmm/device_uvector.hpp>
#include <rmm/device_vector.hpp>
#include <rmm/exec_policy.hpp>
#include <thrust/for_each.h>
#include <thrust/device_vector.h>
#include <thrust/iterator/counting_iterator.h>



namespace {
namespace stdex = std::experimental;
void test_mdspan()
Expand All @@ -35,17 +31,16 @@ void test_mdspan()

cudaStream_t stream = nullptr;
rmm::device_uvector<float> a{16ul, stream};
thrust::sequence(rmm::exec_policy(stream), a.begin(), a.end());

stdex::mdspan<float, stdex::extents<stdex::dynamic_extent, stdex::dynamic_extent>> span{
a.data(), 4, 4};

std::cout << "__has_cpp_attribute(no_unique_address):" << __has_cpp_attribute(no_unique_address)
<< std::endl;
thrust::for_each(it, it + 4, [=] __device__(size_t i) {
auto v = span(i, i);
printf("v: %f\n", v);
auto k = stdex::submdspan(span, stdex::full_extent, 0);
printf("k: %f\n", k(i));
auto v = span(0, i);
DEVICE_ASSERT(v == i);
auto k = stdex::submdspan(span, 0, stdex::full_extent);
DEVICE_ASSERT(k(i) == i);
});
}

Expand Down

0 comments on commit a98ba0f

Please sign in to comment.